forked from OSchip/llvm-project
[MLIR] AffineMap value type
This CL applies the same pattern as AffineExpr to AffineMap: a simple struct that acts as the storage is allocated in the bump pointer. The AffineMap is immutable and accessed everywhere by value. PiperOrigin-RevId: 216445930
This commit is contained in:
parent
82e55750d2
commit
1d3e7e2616
|
@ -42,7 +42,7 @@ class HyperRectangularSet;
|
|||
/// A mutable affine map. Its affine expressions are however unique.
|
||||
struct MutableAffineMap {
|
||||
public:
|
||||
MutableAffineMap(AffineMap *map, MLIRContext *context);
|
||||
MutableAffineMap(AffineMap map);
|
||||
|
||||
AffineExpr getResult(unsigned idx) const { return results[idx]; }
|
||||
void setResult(unsigned idx, AffineExpr result) { results[idx] = result; }
|
||||
|
@ -60,9 +60,9 @@ public:
|
|||
//-simplify-affine-expr pass).
|
||||
void simplify();
|
||||
/// Get the AffineMap corresponding to this MutableAffineMap. Note that an
|
||||
/// AffineMap * will be uniqued and stored in context, while a mutable one
|
||||
/// AffineMap will be uniqued and stored in context, while a mutable one
|
||||
/// isn't.
|
||||
AffineMap *getAffineMap();
|
||||
AffineMap getAffineMap();
|
||||
|
||||
private:
|
||||
// Same meaning as AffineMap's fields.
|
||||
|
@ -117,11 +117,10 @@ private:
|
|||
// TODO(bondhugula): Some of these classes could go into separate files.
|
||||
class AffineValueMap {
|
||||
public:
|
||||
AffineValueMap(const AffineApplyOp &op, MLIRContext *context);
|
||||
AffineValueMap(const AffineBound &bound, MLIRContext *context);
|
||||
AffineValueMap(AffineMap *map, MLIRContext *context);
|
||||
AffineValueMap(AffineMap *map, ArrayRef<MLValue *> operands,
|
||||
MLIRContext *context);
|
||||
AffineValueMap(const AffineApplyOp &op);
|
||||
AffineValueMap(const AffineBound &bound);
|
||||
AffineValueMap(AffineMap map);
|
||||
AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands);
|
||||
|
||||
~AffineValueMap();
|
||||
|
||||
|
@ -156,7 +155,7 @@ public:
|
|||
unsigned getNumOperands() const;
|
||||
SSAValue *getOperand(unsigned i) const;
|
||||
ArrayRef<MLValue *> getOperands() const;
|
||||
AffineMap *getAffineMap();
|
||||
AffineMap getAffineMap();
|
||||
|
||||
private:
|
||||
// A mutable affine map.
|
||||
|
|
|
@ -27,9 +27,16 @@
|
|||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace detail {
|
||||
|
||||
class AffineMapStorage;
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
class AffineExpr;
|
||||
class Attribute;
|
||||
class MLIRContext;
|
||||
|
@ -41,71 +48,91 @@ class MLIRContext;
|
|||
/// is unique to this affine map.
|
||||
class AffineMap {
|
||||
public:
|
||||
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> results,
|
||||
ArrayRef<AffineExpr> rangeSizes);
|
||||
typedef detail::AffineMapStorage ImplType;
|
||||
|
||||
explicit AffineMap(ImplType *map = nullptr) : map(map) {}
|
||||
static AffineMap Invalid() { return AffineMap(nullptr); }
|
||||
|
||||
static AffineMap get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> results,
|
||||
ArrayRef<AffineExpr> rangeSizes);
|
||||
|
||||
/// Returns a single constant result affine map.
|
||||
static AffineMap *getConstantMap(int64_t val, MLIRContext *context);
|
||||
static AffineMap getConstantMap(int64_t val, MLIRContext *context);
|
||||
|
||||
explicit operator bool() { return map; }
|
||||
bool operator==(const AffineMap &other) const { return other.map == map; }
|
||||
|
||||
/// Returns true if the co-domain (or more loosely speaking, range) of this
|
||||
/// map is bounded. Bounded affine maps have a size (extent) for each of
|
||||
/// their range dimensions (more accurately co-domain dimensions).
|
||||
bool isBounded() { return !rangeSizes.empty(); }
|
||||
bool isBounded() const;
|
||||
|
||||
/// Returns true if this affine map is an identity affine map.
|
||||
/// An identity affine map corresponds to an identity affine function on the
|
||||
/// dimensional identifiers.
|
||||
bool isIdentity();
|
||||
bool isIdentity() const;
|
||||
|
||||
/// Returns true if this affine map is a single result constant function.
|
||||
bool isSingleConstant();
|
||||
bool isSingleConstant() const;
|
||||
|
||||
/// Returns the constant result of this map. This methods asserts that the map
|
||||
/// has a single constant result.
|
||||
int64_t getSingleConstantResult();
|
||||
int64_t getSingleConstantResult() const;
|
||||
|
||||
// Prints affine map to 'os'.
|
||||
void print(raw_ostream &os);
|
||||
void dump();
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
unsigned getNumDims() { return numDims; }
|
||||
unsigned getNumSymbols() { return numSymbols; }
|
||||
unsigned getNumResults() { return numResults; }
|
||||
unsigned getNumInputs() { return numDims + numSymbols; }
|
||||
unsigned getNumDims() const;
|
||||
unsigned getNumSymbols() const;
|
||||
unsigned getNumResults() const;
|
||||
unsigned getNumInputs() const;
|
||||
|
||||
ArrayRef<AffineExpr> getResults() { return results; }
|
||||
ArrayRef<AffineExpr> getResults() const;
|
||||
AffineExpr getResult(unsigned idx) const;
|
||||
|
||||
AffineExpr getResult(unsigned idx);
|
||||
|
||||
ArrayRef<AffineExpr> getRangeSizes() { return rangeSizes; }
|
||||
ArrayRef<AffineExpr> getRangeSizes() const;
|
||||
|
||||
/// Folds the results of the application of an affine map on the provided
|
||||
/// operands to a constant if possible. Returns false if the folding happens,
|
||||
/// true otherwise.
|
||||
bool constantFold(ArrayRef<Attribute *> operandConstants,
|
||||
SmallVectorImpl<Attribute *> &results);
|
||||
SmallVectorImpl<Attribute *> &results) const;
|
||||
|
||||
friend ::llvm::hash_code hash_value(AffineMap arg);
|
||||
|
||||
private:
|
||||
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
||||
ArrayRef<AffineExpr> results, ArrayRef<AffineExpr> rangeSizes);
|
||||
|
||||
AffineMap(const AffineMap &) = delete;
|
||||
void operator=(const AffineMap &) = delete;
|
||||
|
||||
unsigned numDims;
|
||||
unsigned numSymbols;
|
||||
unsigned numResults;
|
||||
|
||||
/// The affine expressions for this (multi-dimensional) map.
|
||||
/// TODO: use trailing objects for this.
|
||||
ArrayRef<AffineExpr> results;
|
||||
|
||||
/// The extents along each of the range dimensions if the map is bounded,
|
||||
/// nullptr otherwise.
|
||||
ArrayRef<AffineExpr> rangeSizes;
|
||||
ImplType *map;
|
||||
};
|
||||
|
||||
// Make AffineExpr hashable.
|
||||
inline ::llvm::hash_code hash_value(AffineMap arg) {
|
||||
return ::llvm::hash_value(arg.map);
|
||||
}
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
|
||||
// AffineExpr hash just like pointers
|
||||
template <> struct DenseMapInfo<mlir::AffineMap> {
|
||||
static mlir::AffineMap getEmptyKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
|
||||
}
|
||||
static mlir::AffineMap getTombstoneKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::AffineMap val) {
|
||||
return mlir::hash_value(val);
|
||||
}
|
||||
static bool isEqual(mlir::AffineMap LHS, mlir::AffineMap RHS) {
|
||||
return LHS == RHS;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
#endif // MLIR_IR_AFFINE_MAP_H
|
||||
|
|
|
@ -18,11 +18,11 @@
|
|||
#ifndef MLIR_IR_ATTRIBUTES_H
|
||||
#define MLIR_IR_ATTRIBUTES_H
|
||||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
class AffineMap;
|
||||
class Function;
|
||||
class FunctionType;
|
||||
class MLIRContext;
|
||||
|
@ -182,22 +182,21 @@ private:
|
|||
|
||||
class AffineMapAttr : public Attribute {
|
||||
public:
|
||||
static AffineMapAttr *get(AffineMap *value, MLIRContext *context);
|
||||
static AffineMapAttr *get(AffineMap value);
|
||||
|
||||
AffineMap *getValue() const {
|
||||
return value;
|
||||
}
|
||||
AffineMap getValue() const { return value; }
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::AffineMap;
|
||||
}
|
||||
|
||||
private:
|
||||
AffineMapAttr(AffineMap *value)
|
||||
AffineMapAttr(AffineMap value)
|
||||
: Attribute(Kind::AffineMap, /*isOrContainsFunction=*/false),
|
||||
value(value) {}
|
||||
~AffineMapAttr() = delete;
|
||||
AffineMap *value;
|
||||
AffineMap value;
|
||||
};
|
||||
|
||||
class TypeAttr : public Attribute {
|
||||
|
|
|
@ -84,7 +84,7 @@ public:
|
|||
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
|
||||
ArrayRef<Type *> results);
|
||||
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap *> affineMapComposition = {},
|
||||
ArrayRef<AffineMap> affineMapComposition = {},
|
||||
unsigned memorySpace = 0);
|
||||
VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
|
||||
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
|
||||
|
@ -96,7 +96,7 @@ public:
|
|||
FloatAttr *getFloatAttr(double value);
|
||||
StringAttr *getStringAttr(StringRef bytes);
|
||||
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
|
||||
AffineMapAttr *getAffineMapAttr(AffineMap *value);
|
||||
AffineMapAttr *getAffineMapAttr(AffineMap map);
|
||||
TypeAttr *getTypeAttr(Type *type);
|
||||
FunctionAttr *getFunctionAttr(const Function *value);
|
||||
|
||||
|
@ -105,30 +105,30 @@ public:
|
|||
AffineExpr getAffineSymbolExpr(unsigned position);
|
||||
AffineExpr getAffineConstantExpr(int64_t constant);
|
||||
|
||||
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> results,
|
||||
ArrayRef<AffineExpr> rangeSizes);
|
||||
AffineMap getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> results,
|
||||
ArrayRef<AffineExpr> rangeSizes);
|
||||
|
||||
// Special cases of affine maps and integer sets
|
||||
/// Returns a single constant result affine map with 0 dimensions and 0
|
||||
/// symbols. One constant result: () -> (val).
|
||||
AffineMap *getConstantAffineMap(int64_t val);
|
||||
AffineMap getConstantAffineMap(int64_t val);
|
||||
// One dimension id identity map: (i) -> (i).
|
||||
AffineMap *getDimIdentityMap();
|
||||
AffineMap getDimIdentityMap();
|
||||
// Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2).
|
||||
AffineMap *getMultiDimIdentityMap(unsigned rank);
|
||||
AffineMap getMultiDimIdentityMap(unsigned rank);
|
||||
// One symbol identity map: ()[s] -> (s).
|
||||
AffineMap *getSymbolIdentityMap();
|
||||
AffineMap getSymbolIdentityMap();
|
||||
|
||||
/// Returns a map that shifts its (single) input dimension by 'shift'.
|
||||
/// (d0) -> (d0 + shift)
|
||||
AffineMap *getSingleDimShiftAffineMap(int64_t shift);
|
||||
AffineMap getSingleDimShiftAffineMap(int64_t shift);
|
||||
|
||||
/// Returns an affine map that is a translation (shift) of all result
|
||||
/// expressions in 'map' by 'shift'.
|
||||
/// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2
|
||||
/// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2)
|
||||
AffineMap *getShiftedAffineMap(AffineMap *map, int64_t shift);
|
||||
AffineMap getShiftedAffineMap(AffineMap map, int64_t shift);
|
||||
|
||||
// Integer set.
|
||||
IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||
|
@ -392,8 +392,8 @@ public:
|
|||
|
||||
// Creates a for statement. When step is not specified, it is set to 1.
|
||||
ForStmt *createFor(Location *location, ArrayRef<MLValue *> lbOperands,
|
||||
AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
AffineMap *ubMap, int64_t step = 1);
|
||||
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
AffineMap ubMap, int64_t step = 1);
|
||||
|
||||
// Creates a for statement with known (constant) lower and upper bounds.
|
||||
// Default step is 1.
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#ifndef MLIR_IR_OPIMPLEMENTATION_H
|
||||
#define MLIR_IR_OPIMPLEMENTATION_H
|
||||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
|
@ -69,7 +70,7 @@ public:
|
|||
virtual void printType(const Type *type) = 0;
|
||||
virtual void printFunctionReference(const Function *func) = 0;
|
||||
virtual void printAttribute(const Attribute *attr) = 0;
|
||||
virtual void printAffineMap(AffineMap *map) = 0;
|
||||
virtual void printAffineMap(AffineMap map) = 0;
|
||||
virtual void printAffineExpr(AffineExpr expr) = 0;
|
||||
|
||||
/// If the specified operation has attributes, print out an attribute
|
||||
|
@ -104,8 +105,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Attribute &attr) {
|
|||
return p;
|
||||
}
|
||||
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const AffineMap &map) {
|
||||
p.printAffineMap(&const_cast<AffineMap &>(map));
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, AffineMap map) {
|
||||
p.printAffineMap(map);
|
||||
return p;
|
||||
}
|
||||
|
||||
|
|
|
@ -86,11 +86,11 @@ class AffineApplyOp : public Op<AffineApplyOp, OpTrait::VariadicOperands,
|
|||
OpTrait::VariadicResults> {
|
||||
public:
|
||||
/// Builds an affine apply op with the specified map and operands.
|
||||
static void build(Builder *builder, OperationState *result, AffineMap *map,
|
||||
static void build(Builder *builder, OperationState *result, AffineMap map,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
|
||||
/// Returns the affine map to be applied by this operation.
|
||||
AffineMap *getAffineMap() const {
|
||||
AffineMap getAffineMap() const {
|
||||
return getAttrOfType<AffineMapAttr>("map")->getValue();
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#ifndef MLIR_IR_STATEMENTS_H
|
||||
#define MLIR_IR_STATEMENTS_H
|
||||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/MLValue.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StmtBlock.h"
|
||||
|
@ -29,7 +30,6 @@
|
|||
#include "llvm/Support/TrailingObjects.h"
|
||||
|
||||
namespace mlir {
|
||||
class AffineMap;
|
||||
class AffineBound;
|
||||
class IntegerSet;
|
||||
class AffineCondition;
|
||||
|
@ -199,8 +199,8 @@ private:
|
|||
class ForStmt : public Statement, public MLValue, public StmtBlock {
|
||||
public:
|
||||
static ForStmt *create(Location *location, ArrayRef<MLValue *> lbOperands,
|
||||
AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
AffineMap *ubMap, int64_t step, MLIRContext *context);
|
||||
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
AffineMap ubMap, int64_t step);
|
||||
|
||||
~ForStmt() {
|
||||
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
|
||||
|
@ -235,20 +235,20 @@ public:
|
|||
int64_t getStep() const { return step; }
|
||||
|
||||
/// Returns affine map for the lower bound.
|
||||
AffineMap *getLowerBoundMap() const { return lbMap; }
|
||||
AffineMap getLowerBoundMap() const { return lbMap; }
|
||||
/// Returns affine map for the upper bound.
|
||||
AffineMap *getUpperBoundMap() const { return ubMap; }
|
||||
AffineMap getUpperBoundMap() const { return ubMap; }
|
||||
|
||||
/// Set lower bound.
|
||||
void setLowerBound(ArrayRef<MLValue *> operands, AffineMap *map);
|
||||
void setLowerBound(ArrayRef<MLValue *> operands, AffineMap map);
|
||||
/// Set upper bound.
|
||||
void setUpperBound(ArrayRef<MLValue *> operands, AffineMap *map);
|
||||
void setUpperBound(ArrayRef<MLValue *> operands, AffineMap map);
|
||||
|
||||
/// Set the lower bound map without changing operands.
|
||||
void setLowerBoundMap(AffineMap *map);
|
||||
void setLowerBoundMap(AffineMap map);
|
||||
|
||||
/// Set the upper bound map without changing operands.
|
||||
void setUpperBoundMap(AffineMap *map);
|
||||
void setUpperBoundMap(AffineMap map);
|
||||
|
||||
/// Set loop step.
|
||||
void setStep(int64_t step) {
|
||||
|
@ -353,9 +353,9 @@ public:
|
|||
|
||||
private:
|
||||
// Affine map for the lower bound.
|
||||
AffineMap *lbMap;
|
||||
AffineMap lbMap;
|
||||
// Affine map for the upper bound.
|
||||
AffineMap *ubMap;
|
||||
AffineMap ubMap;
|
||||
// Positive constant step. Since index is stored as an int64_t, we restrict
|
||||
// step to the set of positive integers that int64_t can represent.
|
||||
int64_t step;
|
||||
|
@ -364,8 +364,8 @@ private:
|
|||
// bound.
|
||||
std::vector<StmtOperand> operands;
|
||||
|
||||
explicit ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,
|
||||
AffineMap *ubMap, int64_t step, MLIRContext *context);
|
||||
explicit ForStmt(Location *location, unsigned numOperands, AffineMap lbMap,
|
||||
AffineMap ubMap, int64_t step);
|
||||
};
|
||||
|
||||
/// AffineBound represents a lower or upper bound in the for statement.
|
||||
|
@ -375,7 +375,7 @@ private:
|
|||
class AffineBound {
|
||||
public:
|
||||
const ForStmt *getForStmt() const { return &stmt; }
|
||||
AffineMap *getMap() const { return map; }
|
||||
AffineMap getMap() const { return map; }
|
||||
|
||||
unsigned getNumOperands() const { return opEnd - opStart; }
|
||||
const MLValue *getOperand(unsigned idx) const {
|
||||
|
@ -411,12 +411,11 @@ private:
|
|||
// the containing 'for' statement operands.
|
||||
unsigned opStart, opEnd;
|
||||
// Affine map for this bound.
|
||||
AffineMap *map;
|
||||
AffineMap map;
|
||||
|
||||
AffineBound(const ForStmt &stmt, const unsigned opStart, const unsigned opEnd,
|
||||
const AffineMap *map)
|
||||
: stmt(stmt), opStart(opStart), opEnd(opEnd),
|
||||
map(const_cast<AffineMap *>(map)) {}
|
||||
AffineBound(const ForStmt &stmt, unsigned opStart, unsigned opEnd,
|
||||
AffineMap map)
|
||||
: stmt(stmt), opStart(opStart), opEnd(opEnd), map(map) {}
|
||||
|
||||
friend class ForStmt;
|
||||
};
|
||||
|
|
|
@ -408,7 +408,7 @@ public:
|
|||
/// Get or create a new MemRefType based on shape, element type, affine
|
||||
/// map composition, and memory space.
|
||||
static MemRefType *get(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap*> affineMapComposition,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace);
|
||||
|
||||
unsigned getRank() const { return getShape().size(); }
|
||||
|
@ -426,7 +426,7 @@ public:
|
|||
|
||||
/// Returns an array of affine map pointers representing the memref affine
|
||||
/// map composition.
|
||||
ArrayRef<AffineMap*> getAffineMaps() const;
|
||||
ArrayRef<AffineMap> getAffineMaps() const;
|
||||
|
||||
/// Returns the memory space in which data referred to by this memref resides.
|
||||
unsigned getMemorySpace() const { return memorySpace; }
|
||||
|
@ -446,12 +446,12 @@ private:
|
|||
/// The number of affine maps in the 'affineMapList' array.
|
||||
const unsigned numAffineMaps;
|
||||
/// List of affine maps in the memref's layout/index map composition.
|
||||
AffineMap *const *const affineMapList;
|
||||
AffineMap const *affineMapList;
|
||||
/// Memory space in which data referenced by memref resides.
|
||||
const unsigned memorySpace;
|
||||
|
||||
MemRefType(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap*> affineMapList, unsigned memorySpace,
|
||||
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
|
||||
MLIRContext *context);
|
||||
~MemRefType() = delete;
|
||||
};
|
||||
|
|
|
@ -71,15 +71,15 @@ void promoteSingleIterationLoops(MLFunction *f);
|
|||
|
||||
/// Returns the lower bound of the cleanup loop when unrolling a loop
|
||||
/// with the specified unroll factor.
|
||||
AffineMap *getCleanupLoopLowerBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder);
|
||||
AffineMap getCleanupLoopLowerBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder);
|
||||
|
||||
/// Returns the upper bound of an unrolled loop when unrolling with
|
||||
/// the specified trip count, stride, and unroll factor.
|
||||
AffineMap *getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder);
|
||||
AffineMap getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder);
|
||||
|
||||
/// Skew the statements in the body of a 'for' statement with the specified
|
||||
/// statement-wise delays.
|
||||
|
|
|
@ -25,11 +25,11 @@
|
|||
#ifndef MLIR_TRANSFORMS_UTILS_H
|
||||
#define MLIR_TRANSFORMS_UTILS_H
|
||||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AffineMap;
|
||||
class MLValue;
|
||||
class SSAValue;
|
||||
|
||||
|
@ -43,7 +43,7 @@ class SSAValue;
|
|||
// extended to add additional indices at any position.
|
||||
bool replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
|
||||
llvm::ArrayRef<SSAValue *> extraIndices,
|
||||
AffineMap *indexRemap = nullptr);
|
||||
AffineMap indexRemap = AffineMap::Invalid());
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_UTILS_H
|
||||
|
|
|
@ -166,12 +166,13 @@ forwardSubstituteMutableAffineMap(const AffineMapCompositionUpdate &mapUpdate,
|
|||
map->setNumSymbols(mapUpdate.outputNumSymbols);
|
||||
}
|
||||
|
||||
MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
|
||||
: numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
|
||||
context(context) {
|
||||
for (auto result : map->getResults())
|
||||
MutableAffineMap::MutableAffineMap(AffineMap map)
|
||||
: numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
|
||||
// A map always has at leat 1 result by construction
|
||||
context(map.getResult(0).getContext()) {
|
||||
for (auto result : map.getResults())
|
||||
results.push_back(result);
|
||||
for (auto rangeSize : map->getRangeSizes())
|
||||
for (auto rangeSize : map.getRangeSizes())
|
||||
results.push_back(rangeSize);
|
||||
}
|
||||
|
||||
|
@ -194,7 +195,7 @@ void MutableAffineMap::simplify() {
|
|||
}
|
||||
}
|
||||
|
||||
AffineMap *MutableAffineMap::getAffineMap() {
|
||||
AffineMap MutableAffineMap::getAffineMap() {
|
||||
return AffineMap::get(numDims, numSymbols, results, rangeSizes);
|
||||
}
|
||||
|
||||
|
@ -209,17 +210,16 @@ MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
|
|||
MLIRContext *context)
|
||||
: numDims(numDims), numSymbols(numSymbols), context(context) {}
|
||||
|
||||
AffineValueMap::AffineValueMap(const AffineApplyOp &op, MLIRContext *context)
|
||||
: map(op.getAffineMap(), context) {
|
||||
AffineValueMap::AffineValueMap(const AffineApplyOp &op)
|
||||
: map(op.getAffineMap()) {
|
||||
for (auto *operand : op.getOperands())
|
||||
operands.push_back(cast<MLValue>(const_cast<SSAValue *>(operand)));
|
||||
for (unsigned i = 0, e = op.getNumResults(); i < e; i++)
|
||||
results.push_back(cast<MLValue>(const_cast<SSAValue *>(op.getResult(i))));
|
||||
}
|
||||
|
||||
AffineValueMap::AffineValueMap(AffineMap *map, ArrayRef<MLValue *> operands,
|
||||
MLIRContext *context)
|
||||
: map(map, context) {
|
||||
AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands)
|
||||
: map(map) {
|
||||
for (MLValue *operand : operands) {
|
||||
this->operands.push_back(operand);
|
||||
}
|
||||
|
@ -303,20 +303,20 @@ void AffineValueMap::forwardSubstitute(const AffineApplyOp &inputOp) {
|
|||
|
||||
// Gather dim and symbol positions from 'inputOp' on which
|
||||
// 'inputResultsUsed' depend.
|
||||
AffineMap *inputMap = inputOp.getAffineMap();
|
||||
unsigned inputNumDims = inputMap->getNumDims();
|
||||
AffineMap inputMap = inputOp.getAffineMap();
|
||||
unsigned inputNumDims = inputMap.getNumDims();
|
||||
DenseSet<unsigned> inputPositionsUsed;
|
||||
AffineExprPositionGatherer gatherer(inputNumDims, &inputPositionsUsed);
|
||||
for (unsigned i = 0; i < inputNumResults; ++i) {
|
||||
if (inputResultsUsed.count(i) == 0)
|
||||
continue;
|
||||
gatherer.walkPostOrder(inputMap->getResult(i));
|
||||
gatherer.walkPostOrder(inputMap.getResult(i));
|
||||
}
|
||||
|
||||
// Build new output operands list and map update.
|
||||
SmallVector<MLValue *, 4> outputOperands;
|
||||
unsigned outputOperandPosition = 0;
|
||||
AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap()->getResults());
|
||||
AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap().getResults());
|
||||
|
||||
// Add dim operands from current map.
|
||||
for (unsigned i = 0; i < currNumDims; ++i) {
|
||||
|
@ -405,7 +405,7 @@ ArrayRef<MLValue *> AffineValueMap::getOperands() const {
|
|||
return ArrayRef<MLValue *>(operands);
|
||||
}
|
||||
|
||||
AffineMap *AffineValueMap::getAffineMap() { return map.getAffineMap(); }
|
||||
AffineMap AffineValueMap::getAffineMap() { return map.getAffineMap(); }
|
||||
|
||||
AffineValueMap::~AffineValueMap() {}
|
||||
|
||||
|
|
|
@ -44,10 +44,10 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
|
|||
int64_t ub = forStmt.getConstantUpperBound();
|
||||
loopSpan = ub - lb + 1;
|
||||
} else {
|
||||
auto *lbMap = forStmt.getLowerBoundMap();
|
||||
auto *ubMap = forStmt.getUpperBoundMap();
|
||||
auto lbMap = forStmt.getLowerBoundMap();
|
||||
auto ubMap = forStmt.getUpperBoundMap();
|
||||
// TODO(bondhugula): handle max/min of multiple expressions.
|
||||
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
|
||||
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
|
||||
return nullptr;
|
||||
|
||||
// TODO(bondhugula): handle bounds with different operands.
|
||||
|
@ -56,11 +56,11 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
|
|||
return nullptr;
|
||||
|
||||
// ub_expr - lb_expr + 1
|
||||
AffineExpr lbExpr(lbMap->getResult(0));
|
||||
AffineExpr ubExpr(ubMap->getResult(0));
|
||||
AffineExpr lbExpr(lbMap.getResult(0));
|
||||
AffineExpr ubExpr(ubMap.getResult(0));
|
||||
auto loopSpanExpr = simplifyAffineExpr(
|
||||
ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
|
||||
std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
|
||||
ubExpr - lbExpr + 1, std::max(lbMap.getNumDims(), ubMap.getNumDims()),
|
||||
std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols()));
|
||||
auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
|
||||
if (!cExpr)
|
||||
return loopSpanExpr.ceilDiv(step);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "AffineMapDetail.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
|
@ -87,13 +88,15 @@ private:
|
|||
|
||||
} // end anonymous namespace
|
||||
|
||||
AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
||||
ArrayRef<AffineExpr> results,
|
||||
ArrayRef<AffineExpr> rangeSizes)
|
||||
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
|
||||
results(results), rangeSizes(rangeSizes) {}
|
||||
/// Returns a single constant result affine map.
|
||||
AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
|
||||
return get(/*dimCount=*/0, /*symbolCount=*/0,
|
||||
{getAffineConstantExpr(val, context)}, {});
|
||||
}
|
||||
|
||||
bool AffineMap::isIdentity() {
|
||||
bool AffineMap::isBounded() const { return !map->rangeSizes.empty(); }
|
||||
|
||||
bool AffineMap::isIdentity() const {
|
||||
if (getNumDims() != getNumResults())
|
||||
return false;
|
||||
ArrayRef<AffineExpr> results = getResults();
|
||||
|
@ -105,28 +108,35 @@ bool AffineMap::isIdentity() {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Returns a single constant result affine map.
|
||||
AffineMap *AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
|
||||
return get(/*dimCount=*/0, /*symbolCount=*/0,
|
||||
{getAffineConstantExpr(val, context)}, {});
|
||||
}
|
||||
|
||||
bool AffineMap::isSingleConstant() {
|
||||
bool AffineMap::isSingleConstant() const {
|
||||
return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
|
||||
}
|
||||
|
||||
int64_t AffineMap::getSingleConstantResult() {
|
||||
int64_t AffineMap::getSingleConstantResult() const {
|
||||
assert(isSingleConstant() && "map must have a single constant result");
|
||||
return getResult(0).cast<AffineConstantExpr>().getValue();
|
||||
}
|
||||
|
||||
AffineExpr AffineMap::getResult(unsigned idx) { return results[idx]; }
|
||||
unsigned AffineMap::getNumDims() const { return map->numDims; }
|
||||
unsigned AffineMap::getNumSymbols() const { return map->numSymbols; }
|
||||
unsigned AffineMap::getNumResults() const { return map->numResults; }
|
||||
unsigned AffineMap::getNumInputs() const {
|
||||
return map->numDims + map->numSymbols;
|
||||
}
|
||||
|
||||
ArrayRef<AffineExpr> AffineMap::getResults() const { return map->results; }
|
||||
AffineExpr AffineMap::getResult(unsigned idx) const {
|
||||
return map->results[idx];
|
||||
}
|
||||
ArrayRef<AffineExpr> AffineMap::getRangeSizes() const {
|
||||
return map->rangeSizes;
|
||||
}
|
||||
|
||||
/// Folds the results of the application of an affine map on the provided
|
||||
/// operands to a constant if possible. Returns false if the folding happens,
|
||||
/// true otherwise.
|
||||
bool AffineMap::constantFold(ArrayRef<Attribute *> operandConstants,
|
||||
SmallVectorImpl<Attribute *> &results) {
|
||||
SmallVectorImpl<Attribute *> &results) const {
|
||||
assert(getNumInputs() == operandConstants.size());
|
||||
|
||||
// Fold each of the result expressions.
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
//===- AffineMapDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This holds implementation details of AffineMap.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef AFFINEMAPDETAIL_H_
|
||||
#define AFFINEMAPDETAIL_H_
|
||||
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
|
||||
struct AffineMapStorage {
|
||||
unsigned numDims;
|
||||
unsigned numSymbols;
|
||||
unsigned numResults;
|
||||
|
||||
/// The affine expressions for this (multi-dimensional) map.
|
||||
/// TODO: use trailing objects for this.
|
||||
ArrayRef<AffineExpr> results;
|
||||
|
||||
/// The extents along each of the range dimensions if the map is bounded,
|
||||
/// nullptr otherwise.
|
||||
ArrayRef<AffineExpr> rangeSizes;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // AFFINEMAPDETAIL_H_
|
|
@ -64,7 +64,7 @@ public:
|
|||
// Initializes module state, populating affine map state.
|
||||
void initialize(const Module *module);
|
||||
|
||||
int getAffineMapId(AffineMap *affineMap) const {
|
||||
int getAffineMapId(AffineMap affineMap) const {
|
||||
auto it = affineMapIds.find(affineMap);
|
||||
if (it == affineMapIds.end()) {
|
||||
return -1;
|
||||
|
@ -72,7 +72,7 @@ public:
|
|||
return it->second;
|
||||
}
|
||||
|
||||
ArrayRef<AffineMap *> getAffineMapIds() const { return affineMapsById; }
|
||||
ArrayRef<AffineMap> getAffineMapIds() const { return affineMapsById; }
|
||||
|
||||
int getIntegerSetId(IntegerSet *integerSet) const {
|
||||
auto it = integerSetIds.find(integerSet);
|
||||
|
@ -85,7 +85,7 @@ public:
|
|||
ArrayRef<IntegerSet *> getIntegerSetIds() const { return integerSetsById; }
|
||||
|
||||
private:
|
||||
void recordAffineMapReference(AffineMap *affineMap) {
|
||||
void recordAffineMapReference(AffineMap affineMap) {
|
||||
if (affineMapIds.count(affineMap) == 0) {
|
||||
affineMapIds[affineMap] = affineMapsById.size();
|
||||
affineMapsById.push_back(affineMap);
|
||||
|
@ -100,15 +100,15 @@ private:
|
|||
}
|
||||
|
||||
// Return true if this map could be printed using the shorthand form.
|
||||
static bool hasShorthandForm(AffineMap *boundMap) {
|
||||
if (boundMap->isSingleConstant())
|
||||
static bool hasShorthandForm(AffineMap boundMap) {
|
||||
if (boundMap.isSingleConstant())
|
||||
return true;
|
||||
|
||||
// Check if the affine map is single dim id or single symbol identity -
|
||||
// (i)->(i) or ()[s]->(i)
|
||||
return boundMap->getNumInputs() == 1 && boundMap->getNumResults() == 1 &&
|
||||
(boundMap->getResult(0).isa<AffineDimExpr>() ||
|
||||
boundMap->getResult(0).isa<AffineSymbolExpr>());
|
||||
return boundMap.getNumInputs() == 1 && boundMap.getNumResults() == 1 &&
|
||||
(boundMap.getResult(0).isa<AffineDimExpr>() ||
|
||||
boundMap.getResult(0).isa<AffineSymbolExpr>());
|
||||
}
|
||||
|
||||
// Visit functions.
|
||||
|
@ -124,8 +124,8 @@ private:
|
|||
void visitAttribute(const Attribute *attr);
|
||||
void visitOperation(const Operation *op);
|
||||
|
||||
DenseMap<AffineMap *, int> affineMapIds;
|
||||
std::vector<AffineMap *> affineMapsById;
|
||||
DenseMap<AffineMap, int> affineMapIds;
|
||||
std::vector<AffineMap> affineMapsById;
|
||||
|
||||
DenseMap<IntegerSet *, int> integerSetIds;
|
||||
std::vector<IntegerSet *> integerSetsById;
|
||||
|
@ -142,7 +142,7 @@ void ModuleState::visitType(const Type *type) {
|
|||
visitType(result);
|
||||
} else if (auto *memref = dyn_cast<MemRefType>(type)) {
|
||||
// Visit affine maps in memref type.
|
||||
for (auto *map : memref->getAffineMaps()) {
|
||||
for (auto map : memref->getAffineMaps()) {
|
||||
recordAffineMapReference(map);
|
||||
}
|
||||
}
|
||||
|
@ -193,11 +193,11 @@ void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
|
|||
}
|
||||
|
||||
void ModuleState::visitForStmt(const ForStmt *forStmt) {
|
||||
AffineMap *lbMap = forStmt->getLowerBoundMap();
|
||||
AffineMap lbMap = forStmt->getLowerBoundMap();
|
||||
if (!hasShorthandForm(lbMap))
|
||||
recordAffineMapReference(lbMap);
|
||||
|
||||
AffineMap *ubMap = forStmt->getUpperBoundMap();
|
||||
AffineMap ubMap = forStmt->getUpperBoundMap();
|
||||
if (!hasShorthandForm(ubMap))
|
||||
recordAffineMapReference(ubMap);
|
||||
|
||||
|
@ -273,7 +273,7 @@ public:
|
|||
void print(const CFGFunction *fn);
|
||||
void print(const MLFunction *fn);
|
||||
|
||||
void printAffineMap(AffineMap *map);
|
||||
void printAffineMap(AffineMap map);
|
||||
void printAffineExpr(AffineExpr expr);
|
||||
void printAffineConstraint(AffineExpr expr, bool isEq);
|
||||
void printIntegerSet(IntegerSet *set);
|
||||
|
@ -288,7 +288,7 @@ protected:
|
|||
ArrayRef<const char *> elidedAttrs = {});
|
||||
void printFunctionResultType(const FunctionType *type);
|
||||
void printAffineMapId(int affineMapId) const;
|
||||
void printAffineMapReference(AffineMap *affineMap);
|
||||
void printAffineMapReference(AffineMap affineMap);
|
||||
void printIntegerSetId(int integerSetId) const;
|
||||
void printIntegerSetReference(IntegerSet *integerSet);
|
||||
|
||||
|
@ -321,14 +321,14 @@ void ModulePrinter::printAffineMapId(int affineMapId) const {
|
|||
os << "#map" << affineMapId;
|
||||
}
|
||||
|
||||
void ModulePrinter::printAffineMapReference(AffineMap *affineMap) {
|
||||
void ModulePrinter::printAffineMapReference(AffineMap affineMap) {
|
||||
int mapId = state.getAffineMapId(affineMap);
|
||||
if (mapId >= 0) {
|
||||
// Map will be printed at top of module so print reference to its id.
|
||||
printAffineMapId(mapId);
|
||||
} else {
|
||||
// Map not in module state so print inline.
|
||||
affineMap->print(os);
|
||||
affineMap.print(os);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -352,7 +352,7 @@ void ModulePrinter::print(const Module *module) {
|
|||
for (const auto &map : state.getAffineMapIds()) {
|
||||
printAffineMapId(state.getAffineMapId(map));
|
||||
os << " = ";
|
||||
map->print(os);
|
||||
map.print(os);
|
||||
os << '\n';
|
||||
}
|
||||
for (const auto &set : state.getIntegerSetIds()) {
|
||||
|
@ -678,40 +678,40 @@ void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
|
|||
isEq ? os << " == 0" : os << " >= 0";
|
||||
}
|
||||
|
||||
void ModulePrinter::printAffineMap(AffineMap *map) {
|
||||
void ModulePrinter::printAffineMap(AffineMap map) {
|
||||
// Dimension identifiers.
|
||||
os << '(';
|
||||
for (int i = 0; i < (int)map->getNumDims() - 1; ++i)
|
||||
for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
|
||||
os << 'd' << i << ", ";
|
||||
if (map->getNumDims() >= 1)
|
||||
os << 'd' << map->getNumDims() - 1;
|
||||
if (map.getNumDims() >= 1)
|
||||
os << 'd' << map.getNumDims() - 1;
|
||||
os << ')';
|
||||
|
||||
// Symbolic identifiers.
|
||||
if (map->getNumSymbols() != 0) {
|
||||
if (map.getNumSymbols() != 0) {
|
||||
os << '[';
|
||||
for (unsigned i = 0; i < map->getNumSymbols() - 1; ++i)
|
||||
for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
|
||||
os << 's' << i << ", ";
|
||||
if (map->getNumSymbols() >= 1)
|
||||
os << 's' << map->getNumSymbols() - 1;
|
||||
if (map.getNumSymbols() >= 1)
|
||||
os << 's' << map.getNumSymbols() - 1;
|
||||
os << ']';
|
||||
}
|
||||
|
||||
// AffineMap should have at least one result.
|
||||
assert(!map->getResults().empty());
|
||||
assert(!map.getResults().empty());
|
||||
// Result affine expressions.
|
||||
os << " -> (";
|
||||
interleaveComma(map->getResults(),
|
||||
interleaveComma(map.getResults(),
|
||||
[&](AffineExpr expr) { printAffineExpr(expr); });
|
||||
os << ')';
|
||||
|
||||
if (!map->isBounded()) {
|
||||
if (!map.isBounded()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Print range sizes for bounded affine maps.
|
||||
os << " size (";
|
||||
interleaveComma(map->getRangeSizes(),
|
||||
interleaveComma(map.getRangeSizes(),
|
||||
[&](AffineExpr expr) { printAffineExpr(expr); });
|
||||
os << ')';
|
||||
}
|
||||
|
@ -851,7 +851,7 @@ public:
|
|||
void printAttribute(const Attribute *attr) {
|
||||
ModulePrinter::printAttribute(attr);
|
||||
}
|
||||
void printAffineMap(AffineMap *map) {
|
||||
void printAffineMap(AffineMap map) {
|
||||
return ModulePrinter::printAffineMapReference(map);
|
||||
}
|
||||
void printIntegerSet(IntegerSet *set) {
|
||||
|
@ -1422,7 +1422,7 @@ void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<StmtOperand> ops,
|
|||
}
|
||||
|
||||
void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
|
||||
AffineMap *map = bound.getMap();
|
||||
AffineMap map = bound.getMap();
|
||||
|
||||
// Check if this bound should be printed using short-hand notation.
|
||||
// The decision to restrict printing short-hand notation to trivial cases
|
||||
|
@ -1430,11 +1430,11 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
|
|||
// lossless way.
|
||||
// Therefore, short-hand parsing and printing is only supported for
|
||||
// zero-operand constant maps and single symbol operand identity maps.
|
||||
if (map->getNumResults() == 1) {
|
||||
AffineExpr expr = map->getResult(0);
|
||||
if (map.getNumResults() == 1) {
|
||||
AffineExpr expr = map.getResult(0);
|
||||
|
||||
// Print constant bound.
|
||||
if (map->getNumDims() == 0 && map->getNumSymbols() == 0) {
|
||||
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
|
||||
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
|
||||
os << constExpr.getValue();
|
||||
return;
|
||||
|
@ -1443,7 +1443,7 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
|
|||
|
||||
// Print bound that consists of a single SSA symbol if the map is over a
|
||||
// single symbol.
|
||||
if (map->getNumDims() == 0 && map->getNumSymbols() == 1) {
|
||||
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
|
||||
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
|
||||
printOperand(bound.getOperand(0));
|
||||
return;
|
||||
|
@ -1456,7 +1456,7 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
|
|||
|
||||
// Print the map and its operands.
|
||||
printAffineMapReference(map);
|
||||
printDimAndSymbolList(bound.getStmtOperands(), map->getNumDims());
|
||||
printDimAndSymbolList(bound.getStmtOperands(), map.getNumDims());
|
||||
}
|
||||
|
||||
void MLFunctionPrinter::print(const IfStmt *stmt) {
|
||||
|
@ -1496,7 +1496,7 @@ void Type::print(raw_ostream &os) const {
|
|||
|
||||
void Type::dump() const { print(llvm::errs()); }
|
||||
|
||||
void AffineMap::dump() {
|
||||
void AffineMap::dump() const {
|
||||
print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
}
|
||||
|
@ -1516,9 +1516,9 @@ void AffineExpr::dump() const {
|
|||
llvm::errs() << "\n";
|
||||
}
|
||||
|
||||
void AffineMap::print(raw_ostream &os) {
|
||||
void AffineMap::print(raw_ostream &os) const {
|
||||
ModuleState state(/*no context is known*/ nullptr);
|
||||
ModulePrinter(os, state).printAffineMap(this);
|
||||
ModulePrinter(os, state).printAffineMap(*this);
|
||||
}
|
||||
|
||||
void IntegerSet::print(raw_ostream &os) {
|
||||
|
|
|
@ -90,7 +90,7 @@ FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
|
|||
}
|
||||
|
||||
MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap *> affineMapComposition,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace) {
|
||||
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
|
||||
}
|
||||
|
@ -133,8 +133,8 @@ ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
|
|||
return ArrayAttr::get(value, context);
|
||||
}
|
||||
|
||||
AffineMapAttr *Builder::getAffineMapAttr(AffineMap *map) {
|
||||
return AffineMapAttr::get(map, context);
|
||||
AffineMapAttr *Builder::getAffineMapAttr(AffineMap map) {
|
||||
return AffineMapAttr::get(map);
|
||||
}
|
||||
|
||||
TypeAttr *Builder::getTypeAttr(Type *type) {
|
||||
|
@ -149,9 +149,9 @@ FunctionAttr *Builder::getFunctionAttr(const Function *value) {
|
|||
// Affine Expressions, Affine Maps, and Integet Sets.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> results,
|
||||
ArrayRef<AffineExpr> rangeSizes) {
|
||||
AffineMap Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> results,
|
||||
ArrayRef<AffineExpr> rangeSizes) {
|
||||
return AffineMap::get(dimCount, symbolCount, results, rangeSizes);
|
||||
}
|
||||
|
||||
|
@ -173,17 +173,17 @@ IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
|||
return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context);
|
||||
}
|
||||
|
||||
AffineMap *Builder::getConstantAffineMap(int64_t val) {
|
||||
AffineMap Builder::getConstantAffineMap(int64_t val) {
|
||||
return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
|
||||
{getAffineConstantExpr(val)}, {});
|
||||
}
|
||||
|
||||
AffineMap *Builder::getDimIdentityMap() {
|
||||
AffineMap Builder::getDimIdentityMap() {
|
||||
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
|
||||
{getAffineDimExpr(0)}, {});
|
||||
}
|
||||
|
||||
AffineMap *Builder::getMultiDimIdentityMap(unsigned rank) {
|
||||
AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
|
||||
SmallVector<AffineExpr, 4> dimExprs;
|
||||
dimExprs.reserve(rank);
|
||||
for (unsigned i = 0; i < rank; ++i)
|
||||
|
@ -191,25 +191,25 @@ AffineMap *Builder::getMultiDimIdentityMap(unsigned rank) {
|
|||
return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, {});
|
||||
}
|
||||
|
||||
AffineMap *Builder::getSymbolIdentityMap() {
|
||||
AffineMap Builder::getSymbolIdentityMap() {
|
||||
return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
|
||||
{getAffineSymbolExpr(0)}, {});
|
||||
}
|
||||
|
||||
AffineMap *Builder::getSingleDimShiftAffineMap(int64_t shift) {
|
||||
AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
|
||||
// expr = d0 + shift.
|
||||
auto expr = getAffineDimExpr(0) + shift;
|
||||
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}, {});
|
||||
}
|
||||
|
||||
AffineMap *Builder::getShiftedAffineMap(AffineMap *map, int64_t shift) {
|
||||
AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
|
||||
SmallVector<AffineExpr, 4> shiftedResults;
|
||||
shiftedResults.reserve(map->getNumResults());
|
||||
for (auto resultExpr : map->getResults()) {
|
||||
shiftedResults.reserve(map.getNumResults());
|
||||
for (auto resultExpr : map.getResults()) {
|
||||
shiftedResults.push_back(resultExpr + shift);
|
||||
}
|
||||
return AffineMap::get(map->getNumDims(), map->getNumSymbols(), shiftedResults,
|
||||
map->getRangeSizes());
|
||||
return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
|
||||
map.getRangeSizes());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -278,19 +278,19 @@ OperationStmt *MLFuncBuilder::createOperation(Location *location,
|
|||
|
||||
ForStmt *MLFuncBuilder::createFor(Location *location,
|
||||
ArrayRef<MLValue *> lbOperands,
|
||||
AffineMap *lbMap,
|
||||
AffineMap lbMap,
|
||||
ArrayRef<MLValue *> ubOperands,
|
||||
AffineMap *ubMap, int64_t step) {
|
||||
auto *stmt = ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap,
|
||||
step, context);
|
||||
AffineMap ubMap, int64_t step) {
|
||||
auto *stmt =
|
||||
ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
|
||||
block->getStatements().insert(insertPoint, stmt);
|
||||
return stmt;
|
||||
}
|
||||
|
||||
ForStmt *MLFuncBuilder::createFor(Location *location, int64_t lb, int64_t ub,
|
||||
int64_t step) {
|
||||
auto *lbMap = AffineMap::getConstantMap(lb, context);
|
||||
auto *ubMap = AffineMap::getConstantMap(ub, context);
|
||||
auto lbMap = AffineMap::getConstantMap(lb, context);
|
||||
auto ubMap = AffineMap::getConstantMap(ub, context);
|
||||
return createFor(location, {}, lbMap, {}, ubMap, step);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "AffineExprDetail.h"
|
||||
#include "AffineMapDetail.h"
|
||||
#include "AttributeListStorage.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
|
@ -59,13 +60,13 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
|
|||
}
|
||||
};
|
||||
|
||||
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
|
||||
struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
|
||||
// Affine maps are uniqued based on their dim/symbol counts and affine
|
||||
// expressions.
|
||||
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>,
|
||||
ArrayRef<AffineExpr>>;
|
||||
using DenseMapInfo<AffineMap *>::getHashValue;
|
||||
using DenseMapInfo<AffineMap *>::isEqual;
|
||||
using DenseMapInfo<AffineMap>::getHashValue;
|
||||
using DenseMapInfo<AffineMap>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(
|
||||
|
@ -74,11 +75,11 @@ struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
|
|||
hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, AffineMap *rhs) {
|
||||
static bool isEqual(const KeyTy &lhs, AffineMap rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
|
||||
rhs->getResults(), rhs->getRangeSizes());
|
||||
return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
|
||||
rhs.getResults(), rhs.getRangeSizes());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -124,7 +125,7 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
|
|||
// MemRefs are uniqued based on their element type, shape, affine map
|
||||
// composition, and memory space.
|
||||
using KeyTy =
|
||||
std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap *>, unsigned>;
|
||||
std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
|
||||
using DenseMapInfo<MemRefType *>::getHashValue;
|
||||
using DenseMapInfo<MemRefType *>::isEqual;
|
||||
|
||||
|
@ -222,7 +223,7 @@ public:
|
|||
nullptr};
|
||||
|
||||
// Affine map uniquing.
|
||||
using AffineMapSet = DenseSet<AffineMap *, AffineMapKeyInfo>;
|
||||
using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
|
||||
AffineMapSet affineMaps;
|
||||
|
||||
// Affine binary op expression uniquing. Figure out uniquing of dimensional
|
||||
|
@ -267,7 +268,7 @@ public:
|
|||
StringMap<StringAttr *> stringAttrs;
|
||||
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
|
||||
ArrayAttrSet arrayAttrs;
|
||||
DenseMap<AffineMap *, AffineMapAttr *> affineMapAttrs;
|
||||
DenseMap<AffineMap, AffineMapAttr *> affineMapAttrs;
|
||||
DenseMap<Type *, TypeAttr *> typeAttrs;
|
||||
using AttributeListSet =
|
||||
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
|
||||
|
@ -558,7 +559,7 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
|
|||
}
|
||||
|
||||
MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap *> affineMapComposition,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace) {
|
||||
auto *context = elementType->getContext();
|
||||
auto &impl = context->getImpl();
|
||||
|
@ -581,7 +582,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
|||
// Copy the affine map composition into the bump pointer.
|
||||
// TODO(andydavis) Assert that the structure of the composition is valid.
|
||||
affineMapComposition =
|
||||
impl.copyInto(ArrayRef<AffineMap *>(affineMapComposition));
|
||||
impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
|
||||
|
@ -675,8 +676,9 @@ ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) {
|
|||
return *existing.first = result;
|
||||
}
|
||||
|
||||
AffineMapAttr *AffineMapAttr::get(AffineMap *value, MLIRContext *context) {
|
||||
auto *&result = context->getImpl().affineMapAttrs[value];
|
||||
AffineMapAttr *AffineMapAttr::get(AffineMap value) {
|
||||
auto *context = value.getResult(0).getContext();
|
||||
auto &result = context->getImpl().affineMapAttrs[value];
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
|
@ -802,9 +804,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
|||
// AffineMap and AffineExpr uniquing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> results,
|
||||
ArrayRef<AffineExpr> rangeSizes) {
|
||||
AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> results,
|
||||
ArrayRef<AffineExpr> rangeSizes) {
|
||||
// The number of results can't be zero.
|
||||
assert(!results.empty());
|
||||
|
||||
|
@ -814,25 +816,26 @@ AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
|||
|
||||
// Check if we already have this affine map.
|
||||
auto key = std::make_tuple(dimCount, symbolCount, results, rangeSizes);
|
||||
auto existing = impl.affineMaps.insert_as(nullptr, key);
|
||||
auto existing = impl.affineMaps.insert_as(AffineMap(nullptr), key);
|
||||
|
||||
// If we already have it, return that value.
|
||||
if (!existing.second)
|
||||
return *existing.first;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *res = impl.allocator.Allocate<AffineMap>();
|
||||
auto *res = impl.allocator.Allocate<detail::AffineMapStorage>();
|
||||
|
||||
// Copy the results and range sizes into the bump pointer.
|
||||
results = impl.copyInto(results);
|
||||
rangeSizes = impl.copyInto(rangeSizes);
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (res)
|
||||
AffineMap(dimCount, symbolCount, results.size(), results, rangeSizes);
|
||||
new (res) detail::AffineMapStorage{dimCount, symbolCount,
|
||||
static_cast<unsigned>(results.size()),
|
||||
results, rangeSizes};
|
||||
|
||||
// Cache and return it.
|
||||
return *existing.first = res;
|
||||
return *existing.first = AffineMap(res);
|
||||
}
|
||||
|
||||
/// Simplify add expression. Return nullptr if it can't be simplified.
|
||||
|
|
|
@ -101,9 +101,9 @@ Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AffineApplyOp::build(Builder *builder, OperationState *result,
|
||||
AffineMap *map, ArrayRef<SSAValue *> operands) {
|
||||
AffineMap map, ArrayRef<SSAValue *> operands) {
|
||||
result->addOperands(operands);
|
||||
result->types.append(map->getNumResults(), builder->getIndexType());
|
||||
result->types.append(map.getNumResults(), builder->getIndexType());
|
||||
result->addAttribute("map", builder->getAffineMapAttr(map));
|
||||
}
|
||||
|
||||
|
@ -117,22 +117,22 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
parseDimAndSymbolList(parser, result->operands, numDims) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes))
|
||||
return true;
|
||||
auto *map = mapAttr->getValue();
|
||||
auto map = mapAttr->getValue();
|
||||
|
||||
if (map->getNumDims() != numDims ||
|
||||
numDims + map->getNumSymbols() != result->operands.size()) {
|
||||
if (map.getNumDims() != numDims ||
|
||||
numDims + map.getNumSymbols() != result->operands.size()) {
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"dimension or symbol index mismatch");
|
||||
}
|
||||
|
||||
result->types.append(map->getNumResults(), affineIntTy);
|
||||
result->types.append(map.getNumResults(), affineIntTy);
|
||||
return false;
|
||||
}
|
||||
|
||||
void AffineApplyOp::print(OpAsmPrinter *p) const {
|
||||
auto *map = getAffineMap();
|
||||
*p << "affine_apply " << *map;
|
||||
printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
|
||||
auto map = getAffineMap();
|
||||
*p << "affine_apply " << map;
|
||||
printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p);
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
|
||||
}
|
||||
|
||||
|
@ -143,15 +143,15 @@ bool AffineApplyOp::verify() const {
|
|||
return emitOpError("requires an affine map");
|
||||
|
||||
// Check input and output dimensions match.
|
||||
auto *map = affineMapAttr->getValue();
|
||||
auto map = affineMapAttr->getValue();
|
||||
|
||||
// Verify that operand count matches affine map dimension and symbol count.
|
||||
if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
|
||||
if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
|
||||
return emitOpError(
|
||||
"operand count and affine map dimension and symbol count must match");
|
||||
|
||||
// Verify that result count matches affine map result count.
|
||||
if (getNumResults() != map->getNumResults())
|
||||
if (getNumResults() != map.getNumResults())
|
||||
return emitOpError("result count and affine map result count must match");
|
||||
|
||||
return false;
|
||||
|
@ -183,8 +183,8 @@ bool AffineApplyOp::isValidSymbol() const {
|
|||
bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
|
||||
SmallVectorImpl<Attribute *> &results,
|
||||
MLIRContext *context) const {
|
||||
auto *map = getAffineMap();
|
||||
if (map->constantFold(operandConstants, results))
|
||||
auto map = getAffineMap();
|
||||
if (map.constantFold(operandConstants, results))
|
||||
return true;
|
||||
// Return false on success.
|
||||
return false;
|
||||
|
@ -243,11 +243,11 @@ bool AllocOp::verify() const {
|
|||
|
||||
unsigned numSymbols = 0;
|
||||
if (!memRefType->getAffineMaps().empty()) {
|
||||
AffineMap *affineMap = memRefType->getAffineMaps()[0];
|
||||
AffineMap affineMap = memRefType->getAffineMaps()[0];
|
||||
// Store number of symbols used in affine map (used in subsequent check).
|
||||
numSymbols = affineMap->getNumSymbols();
|
||||
numSymbols = affineMap.getNumSymbols();
|
||||
// Verify that the layout affine map matches the rank of the memref.
|
||||
if (affineMap->getNumDims() != memRefType->getRank())
|
||||
if (affineMap.getNumDims() != memRefType->getRank())
|
||||
return emitOpError("affine map dimension count must equal memref rank");
|
||||
}
|
||||
unsigned numDynamicDims = memRefType->getNumDynamicDims();
|
||||
|
|
|
@ -262,17 +262,16 @@ bool OperationStmt::isReturn() const { return is<ReturnOp>(); }
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ForStmt *ForStmt::create(Location *location, ArrayRef<MLValue *> lbOperands,
|
||||
AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
AffineMap *ubMap, int64_t step, MLIRContext *context) {
|
||||
assert(lbOperands.size() == lbMap->getNumInputs() &&
|
||||
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
AffineMap ubMap, int64_t step) {
|
||||
assert(lbOperands.size() == lbMap.getNumInputs() &&
|
||||
"lower bound operand count does not match the affine map");
|
||||
assert(ubOperands.size() == ubMap->getNumInputs() &&
|
||||
assert(ubOperands.size() == ubMap.getNumInputs() &&
|
||||
"upper bound operand count does not match the affine map");
|
||||
assert(step > 0 && "step has to be a positive integer constant");
|
||||
|
||||
unsigned numOperands = lbOperands.size() + ubOperands.size();
|
||||
ForStmt *stmt =
|
||||
new ForStmt(location, numOperands, lbMap, ubMap, step, context);
|
||||
ForStmt *stmt = new ForStmt(location, numOperands, lbMap, ubMap, step);
|
||||
|
||||
unsigned i = 0;
|
||||
for (unsigned e = lbOperands.size(); i != e; ++i)
|
||||
|
@ -284,30 +283,31 @@ ForStmt *ForStmt::create(Location *location, ArrayRef<MLValue *> lbOperands,
|
|||
return stmt;
|
||||
}
|
||||
|
||||
ForStmt::ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,
|
||||
AffineMap *ubMap, int64_t step, MLIRContext *context)
|
||||
ForStmt::ForStmt(Location *location, unsigned numOperands, AffineMap lbMap,
|
||||
AffineMap ubMap, int64_t step)
|
||||
: Statement(Kind::For, location),
|
||||
MLValue(MLValueKind::ForStmt, Type::getIndex(context)),
|
||||
MLValue(MLValueKind::ForStmt,
|
||||
Type::getIndex(lbMap.getResult(0).getContext())),
|
||||
StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
|
||||
operands.reserve(numOperands);
|
||||
}
|
||||
|
||||
const AffineBound ForStmt::getLowerBound() const {
|
||||
return AffineBound(*this, 0, lbMap->getNumInputs(), lbMap);
|
||||
return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap);
|
||||
}
|
||||
|
||||
const AffineBound ForStmt::getUpperBound() const {
|
||||
return AffineBound(*this, lbMap->getNumInputs(), getNumOperands(), ubMap);
|
||||
return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
|
||||
}
|
||||
|
||||
void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap *map) {
|
||||
assert(lbOperands.size() == map->getNumInputs());
|
||||
assert(map->getNumResults() >= 1 && "bound map has at least one result");
|
||||
void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
|
||||
assert(lbOperands.size() == map.getNumInputs());
|
||||
assert(map.getNumResults() >= 1 && "bound map has at least one result");
|
||||
|
||||
SmallVector<MLValue *, 4> ubOperands(getUpperBoundOperands());
|
||||
|
||||
operands.clear();
|
||||
operands.reserve(lbOperands.size() + ubMap->getNumInputs());
|
||||
operands.reserve(lbOperands.size() + ubMap.getNumInputs());
|
||||
for (auto *operand : lbOperands) {
|
||||
operands.emplace_back(StmtOperand(this, operand));
|
||||
}
|
||||
|
@ -317,9 +317,9 @@ void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap *map) {
|
|||
this->lbMap = map;
|
||||
}
|
||||
|
||||
void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap *map) {
|
||||
assert(ubOperands.size() == map->getNumInputs());
|
||||
assert(map->getNumResults() >= 1 && "bound map has at least one result");
|
||||
void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap map) {
|
||||
assert(ubOperands.size() == map.getNumInputs());
|
||||
assert(map.getNumResults() >= 1 && "bound map has at least one result");
|
||||
|
||||
SmallVector<MLValue *, 4> lbOperands(getLowerBoundOperands());
|
||||
|
||||
|
@ -334,34 +334,30 @@ void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap *map) {
|
|||
this->ubMap = map;
|
||||
}
|
||||
|
||||
void ForStmt::setLowerBoundMap(AffineMap *map) {
|
||||
assert(lbMap->getNumDims() == map->getNumDims() &&
|
||||
lbMap->getNumSymbols() == map->getNumSymbols());
|
||||
assert(map->getNumResults() >= 1 && "bound map has at least one result");
|
||||
void ForStmt::setLowerBoundMap(AffineMap map) {
|
||||
assert(lbMap.getNumDims() == map.getNumDims() &&
|
||||
lbMap.getNumSymbols() == map.getNumSymbols());
|
||||
assert(map.getNumResults() >= 1 && "bound map has at least one result");
|
||||
this->lbMap = map;
|
||||
}
|
||||
|
||||
void ForStmt::setUpperBoundMap(AffineMap *map) {
|
||||
assert(ubMap->getNumDims() == map->getNumDims() &&
|
||||
ubMap->getNumSymbols() == map->getNumSymbols());
|
||||
assert(map->getNumResults() >= 1 && "bound map has at least one result");
|
||||
void ForStmt::setUpperBoundMap(AffineMap map) {
|
||||
assert(ubMap.getNumDims() == map.getNumDims() &&
|
||||
ubMap.getNumSymbols() == map.getNumSymbols());
|
||||
assert(map.getNumResults() >= 1 && "bound map has at least one result");
|
||||
this->ubMap = map;
|
||||
}
|
||||
|
||||
bool ForStmt::hasConstantLowerBound() const {
|
||||
return lbMap->isSingleConstant();
|
||||
}
|
||||
bool ForStmt::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
|
||||
|
||||
bool ForStmt::hasConstantUpperBound() const {
|
||||
return ubMap->isSingleConstant();
|
||||
}
|
||||
bool ForStmt::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
|
||||
|
||||
int64_t ForStmt::getConstantLowerBound() const {
|
||||
return lbMap->getSingleConstantResult();
|
||||
return lbMap.getSingleConstantResult();
|
||||
}
|
||||
|
||||
int64_t ForStmt::getConstantUpperBound() const {
|
||||
return ubMap->getSingleConstantResult();
|
||||
return ubMap.getSingleConstantResult();
|
||||
}
|
||||
|
||||
void ForStmt::setConstantLowerBound(int64_t value) {
|
||||
|
@ -373,21 +369,20 @@ void ForStmt::setConstantUpperBound(int64_t value) {
|
|||
}
|
||||
|
||||
ForStmt::operand_range ForStmt::getLowerBoundOperands() {
|
||||
return {operand_begin(),
|
||||
operand_begin() + getLowerBoundMap()->getNumInputs()};
|
||||
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
|
||||
}
|
||||
|
||||
ForStmt::operand_range ForStmt::getUpperBoundOperands() {
|
||||
return {operand_begin() + getLowerBoundMap()->getNumInputs(), operand_end()};
|
||||
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
|
||||
}
|
||||
|
||||
bool ForStmt::matchingBoundOperandList() const {
|
||||
if (lbMap->getNumDims() != ubMap->getNumDims() ||
|
||||
lbMap->getNumSymbols() != ubMap->getNumSymbols())
|
||||
if (lbMap.getNumDims() != ubMap.getNumDims() ||
|
||||
lbMap.getNumSymbols() != ubMap.getNumSymbols())
|
||||
return false;
|
||||
|
||||
unsigned numOperands = lbMap->getNumInputs();
|
||||
for (unsigned i = 0, e = lbMap->getNumInputs(); i < e; i++) {
|
||||
unsigned numOperands = lbMap.getNumInputs();
|
||||
for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
|
||||
// Compare MLValue *'s.
|
||||
if (getOperand(i) != getOperand(numOperands + i))
|
||||
return false;
|
||||
|
@ -419,11 +414,11 @@ bool ForStmt::constantFoldBound(bool lower) {
|
|||
operandConstants.push_back(operandCst);
|
||||
}
|
||||
|
||||
AffineMap *boundMap = lower ? getLowerBoundMap() : getUpperBoundMap();
|
||||
assert(boundMap->getNumResults() >= 1 &&
|
||||
AffineMap boundMap = lower ? getLowerBoundMap() : getUpperBoundMap();
|
||||
assert(boundMap.getNumResults() >= 1 &&
|
||||
"bound maps should have at least one result");
|
||||
SmallVector<Attribute *, 4> foldedResults;
|
||||
if (boundMap->constantFold(operandConstants, foldedResults))
|
||||
if (boundMap.constantFold(operandConstants, foldedResults))
|
||||
return true;
|
||||
|
||||
// Compute the max or min as applicable over the results.
|
||||
|
@ -523,14 +518,14 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
|
|||
}
|
||||
|
||||
if (auto *forStmt = dyn_cast<ForStmt>(this)) {
|
||||
auto *lbMap = forStmt->getLowerBoundMap();
|
||||
auto *ubMap = forStmt->getUpperBoundMap();
|
||||
auto lbMap = forStmt->getLowerBoundMap();
|
||||
auto ubMap = forStmt->getUpperBoundMap();
|
||||
|
||||
auto *newFor = ForStmt::create(
|
||||
getLoc(),
|
||||
ArrayRef<MLValue *>(operands).take_front(lbMap->getNumInputs()), lbMap,
|
||||
ArrayRef<MLValue *>(operands).take_back(ubMap->getNumInputs()), ubMap,
|
||||
forStmt->getStep(), context);
|
||||
ArrayRef<MLValue *>(operands).take_front(lbMap.getNumInputs()), lbMap,
|
||||
ArrayRef<MLValue *>(operands).take_back(ubMap.getNumInputs()), ubMap,
|
||||
forStmt->getStep());
|
||||
|
||||
// Remember the induction variable mapping.
|
||||
operandMap[forStmt] = newFor;
|
||||
|
|
|
@ -77,14 +77,14 @@ UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
|
|||
}
|
||||
|
||||
MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap *> affineMapList,
|
||||
unsigned memorySpace, MLIRContext *context)
|
||||
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
|
||||
MLIRContext *context)
|
||||
: Type(Kind::MemRef, context, shape.size()), elementType(elementType),
|
||||
shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
|
||||
affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
|
||||
|
||||
ArrayRef<AffineMap*> MemRefType::getAffineMaps() const {
|
||||
return ArrayRef<AffineMap*>(affineMapList, numAffineMaps);
|
||||
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
|
||||
return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
|
||||
}
|
||||
|
||||
unsigned MemRefType::getNumDynamicDims() const {
|
||||
|
|
|
@ -69,7 +69,7 @@ public:
|
|||
}
|
||||
|
||||
// A map from affine map identifier to AffineMap.
|
||||
llvm::StringMap<AffineMap *> affineMapDefinitions;
|
||||
llvm::StringMap<AffineMap> affineMapDefinitions;
|
||||
|
||||
// A map from integer set identifier to IntegerSet.
|
||||
llvm::StringMap<IntegerSet *> integerSetDefinitions;
|
||||
|
@ -200,8 +200,8 @@ public:
|
|||
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
|
||||
|
||||
// Polyhedral structures.
|
||||
AffineMap *parseAffineMapInline();
|
||||
AffineMap *parseAffineMapReference();
|
||||
AffineMap parseAffineMapInline();
|
||||
AffineMap parseAffineMapReference();
|
||||
IntegerSet *parseIntegerSetInline();
|
||||
IntegerSet *parseIntegerSetReference();
|
||||
|
||||
|
@ -521,7 +521,7 @@ Type *Parser::parseMemRefType() {
|
|||
return (emitError(typeLoc, "invalid memref element type"), nullptr);
|
||||
|
||||
// Parse semi-affine-map-composition.
|
||||
SmallVector<AffineMap *, 2> affineMapComposition;
|
||||
SmallVector<AffineMap, 2> affineMapComposition;
|
||||
unsigned memorySpace = 0;
|
||||
bool parsedMemorySpace = false;
|
||||
|
||||
|
@ -540,8 +540,8 @@ Type *Parser::parseMemRefType() {
|
|||
// Parse affine map.
|
||||
if (parsedMemorySpace)
|
||||
return emitError("affine map after memory space in memref type");
|
||||
auto *affineMap = parseAffineMapReference();
|
||||
if (affineMap == nullptr)
|
||||
auto affineMap = parseAffineMapReference();
|
||||
if (!affineMap)
|
||||
return ParseFailure;
|
||||
affineMapComposition.push_back(affineMap);
|
||||
}
|
||||
|
@ -728,7 +728,7 @@ Attribute *Parser::parseAttribute() {
|
|||
case Token::hash_identifier:
|
||||
case Token::l_paren: {
|
||||
// Try to parse affine map reference.
|
||||
if (auto *affineMap = parseAffineMapReference())
|
||||
if (auto affineMap = parseAffineMapReference())
|
||||
return builder.getAffineMapAttr(affineMap);
|
||||
return (emitError("expected constant attribute value"), nullptr);
|
||||
}
|
||||
|
@ -827,7 +827,7 @@ class AffineParser : public Parser {
|
|||
public:
|
||||
explicit AffineParser(ParserState &state) : Parser(state) {}
|
||||
|
||||
AffineMap *parseAffineMapInline();
|
||||
AffineMap parseAffineMapInline();
|
||||
IntegerSet *parseIntegerSetInline();
|
||||
|
||||
private:
|
||||
|
@ -1223,22 +1223,22 @@ ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
|
|||
/// dim-size ::= affine-expr | `min` `(` affine-expr ( `,` affine-expr)+ `)`
|
||||
///
|
||||
/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
|
||||
AffineMap *AffineParser::parseAffineMapInline() {
|
||||
AffineMap AffineParser::parseAffineMapInline() {
|
||||
unsigned numDims = 0, numSymbols = 0;
|
||||
|
||||
// List of dimensional identifiers.
|
||||
if (parseDimIdList(numDims))
|
||||
return nullptr;
|
||||
return AffineMap::Invalid();
|
||||
|
||||
// Symbols are optional.
|
||||
if (getToken().is(Token::l_square)) {
|
||||
if (parseSymbolIdList(numSymbols))
|
||||
return nullptr;
|
||||
return AffineMap::Invalid();
|
||||
}
|
||||
|
||||
if (parseToken(Token::arrow, "expected '->' or '['") ||
|
||||
parseToken(Token::l_paren, "expected '(' at start of affine map range"))
|
||||
return nullptr;
|
||||
return AffineMap::Invalid();
|
||||
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
|
@ -1252,7 +1252,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
|
|||
// affine expressions); the list cannot be empty.
|
||||
// Grammar: multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
|
||||
if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false))
|
||||
return nullptr;
|
||||
return AffineMap::Invalid();
|
||||
|
||||
// Parse optional range sizes.
|
||||
// range-sizes ::= (`size` `(` dim-size (`,` dim-size)* `)`)?
|
||||
|
@ -1264,7 +1264,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
|
|||
// Location of the l_paren token (if it exists) for error reporting later.
|
||||
auto loc = getToken().getLoc();
|
||||
if (parseToken(Token::l_paren, "expected '(' at start of affine map range"))
|
||||
return nullptr;
|
||||
return AffineMap::Invalid();
|
||||
|
||||
auto parseRangeSize = [&]() -> ParseResult {
|
||||
auto loc = getToken().getLoc();
|
||||
|
@ -1281,30 +1281,30 @@ AffineMap *AffineParser::parseAffineMapInline() {
|
|||
};
|
||||
|
||||
if (parseCommaSeparatedListUntil(Token::r_paren, parseRangeSize, false))
|
||||
return nullptr;
|
||||
return AffineMap::Invalid();
|
||||
if (exprs.size() > rangeSizes.size())
|
||||
return (emitError(loc, "fewer range sizes than range expressions"),
|
||||
nullptr);
|
||||
AffineMap::Invalid());
|
||||
if (exprs.size() < rangeSizes.size())
|
||||
return (emitError(loc, "more range sizes than range expressions"),
|
||||
nullptr);
|
||||
AffineMap::Invalid());
|
||||
}
|
||||
|
||||
// Parsed a valid affine map.
|
||||
return builder.getAffineMap(numDims, numSymbols, exprs, rangeSizes);
|
||||
}
|
||||
|
||||
AffineMap *Parser::parseAffineMapInline() {
|
||||
AffineMap Parser::parseAffineMapInline() {
|
||||
return AffineParser(state).parseAffineMapInline();
|
||||
}
|
||||
|
||||
AffineMap *Parser::parseAffineMapReference() {
|
||||
AffineMap Parser::parseAffineMapReference() {
|
||||
if (getToken().is(Token::hash_identifier)) {
|
||||
// Parse affine map identifier and verify that it exists.
|
||||
StringRef affineMapId = getTokenSpelling().drop_front();
|
||||
if (getState().affineMapDefinitions.count(affineMapId) == 0)
|
||||
return (emitError("undefined affine map id '" + affineMapId + "'"),
|
||||
nullptr);
|
||||
AffineMap::Invalid());
|
||||
consumeToken(Token::hash_identifier);
|
||||
return getState().affineMapDefinitions[affineMapId];
|
||||
}
|
||||
|
@ -2221,7 +2221,7 @@ private:
|
|||
ParseResult parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
|
||||
unsigned numDims, unsigned numOperands,
|
||||
const char *affineStructName);
|
||||
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap *&map,
|
||||
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map,
|
||||
bool isLower);
|
||||
ParseResult parseIfStmt();
|
||||
ParseResult parseElseClause(IfClause *elseClause);
|
||||
|
@ -2261,7 +2261,7 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
|
||||
// Parse lower bound.
|
||||
SmallVector<MLValue *, 4> lbOperands;
|
||||
AffineMap *lbMap = nullptr;
|
||||
AffineMap lbMap;
|
||||
if (parseBound(lbOperands, lbMap, /*isLower*/ true))
|
||||
return ParseFailure;
|
||||
|
||||
|
@ -2270,7 +2270,7 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
|
||||
// Parse upper bound.
|
||||
SmallVector<MLValue *, 4> ubOperands;
|
||||
AffineMap *ubMap = nullptr;
|
||||
AffineMap ubMap;
|
||||
if (parseBound(ubOperands, ubMap, /*isLower*/ false))
|
||||
return ParseFailure;
|
||||
|
||||
|
@ -2388,7 +2388,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
|
|||
/// shorthand-bound ::= ssa-id | `-`? integer-literal
|
||||
///
|
||||
ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
|
||||
AffineMap *&map, bool isLower) {
|
||||
AffineMap &map, bool isLower) {
|
||||
// 'min' / 'max' prefixes are syntactic sugar. Ignore them.
|
||||
if (isLower)
|
||||
consumeIf(Token::kw_max);
|
||||
|
@ -2401,7 +2401,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
|
|||
if (!map)
|
||||
return ParseFailure;
|
||||
|
||||
if (parseDimAndSymbolList(operands, map->getNumDims(), map->getNumInputs(),
|
||||
if (parseDimAndSymbolList(operands, map.getNumDims(), map.getNumInputs(),
|
||||
"affine map"))
|
||||
return ParseFailure;
|
||||
return ParseSuccess;
|
||||
|
@ -2691,7 +2691,7 @@ ParseResult ModuleParser::parseAffineMapDef() {
|
|||
StringRef affineMapId = getTokenSpelling().drop_front();
|
||||
|
||||
// Check for redefinitions.
|
||||
auto *&entry = getState().affineMapDefinitions[affineMapId];
|
||||
auto &entry = getState().affineMapDefinitions[affineMapId];
|
||||
if (entry)
|
||||
return emitError("redefinition of affine map id '" + affineMapId + "'");
|
||||
|
||||
|
|
|
@ -103,9 +103,9 @@ static void createComposedAffineApplyOp(
|
|||
unsigned rank = memrefType->getRank();
|
||||
assert(indices.size() == rank);
|
||||
// Create identity map with same number of dimensions as 'memrefType'.
|
||||
auto *map = builder->getMultiDimIdentityMap(rank);
|
||||
auto map = builder->getMultiDimIdentityMap(rank);
|
||||
// Initialize AffineValueMap with identity map.
|
||||
AffineValueMap valueMap(map, indices, builder->getContext());
|
||||
AffineValueMap valueMap(map, indices);
|
||||
|
||||
for (auto *opStmt : affineApplyOps) {
|
||||
assert(opStmt->is<AffineApplyOp>());
|
||||
|
|
|
@ -201,14 +201,14 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
if (unrollFactor == 1 || forStmt->getStatements().empty())
|
||||
return false;
|
||||
|
||||
auto *lbMap = forStmt->getLowerBoundMap();
|
||||
auto *ubMap = forStmt->getUpperBoundMap();
|
||||
auto lbMap = forStmt->getLowerBoundMap();
|
||||
auto ubMap = forStmt->getUpperBoundMap();
|
||||
|
||||
// Loops with max/min expressions won't be unrolled here (the output can't be
|
||||
// expressed as an MLFunction in the general case). However, the right way to
|
||||
// do such unrolling for an MLFunction would be to specialize the loop for the
|
||||
// 'hotspot' case and unroll that hotspot.
|
||||
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
|
||||
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
|
||||
return false;
|
||||
|
||||
// Same operand list for lower and upper bound for now.
|
||||
|
@ -229,7 +229,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
|
||||
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
|
||||
auto *clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
|
||||
auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
|
||||
assert(clLbMap &&
|
||||
"cleanup loop lower bound map for single result bound maps can "
|
||||
"always be determined");
|
||||
|
@ -238,7 +238,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
promoteIfSingleIteration(cleanupForStmt);
|
||||
|
||||
// Adjust upper bound.
|
||||
auto *unrolledUbMap =
|
||||
auto unrolledUbMap =
|
||||
getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
|
||||
assert(unrolledUbMap &&
|
||||
"upper bound map can alwayys be determined for an unrolled loop "
|
||||
|
@ -267,7 +267,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
if (!forStmt->use_empty()) {
|
||||
// iv' = iv + 1/2/3...unrollFactor-1;
|
||||
auto d0 = builder.getAffineDimExpr(0);
|
||||
auto *bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
|
||||
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
|
||||
auto *ivUnroll =
|
||||
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
|
||||
->getResult(0);
|
||||
|
|
|
@ -156,14 +156,14 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
|
||||
return false;
|
||||
|
||||
auto *lbMap = forStmt->getLowerBoundMap();
|
||||
auto *ubMap = forStmt->getUpperBoundMap();
|
||||
auto lbMap = forStmt->getLowerBoundMap();
|
||||
auto ubMap = forStmt->getUpperBoundMap();
|
||||
|
||||
// Loops with max/min expressions won't be unrolled here (the output can't be
|
||||
// expressed as an MLFunction in the general case). However, the right way to
|
||||
// do such unrolling for an MLFunction would be to specialize the loop for the
|
||||
// 'hotspot' case and unroll that hotspot.
|
||||
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
|
||||
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
|
||||
return false;
|
||||
|
||||
// Same operand list for lower and upper bound for now.
|
||||
|
@ -221,7 +221,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
if (!forStmt->use_empty()) {
|
||||
// iv' = iv + i, i = 1 to unrollJamFactor-1.
|
||||
auto d0 = builder.getAffineDimExpr(0);
|
||||
auto *bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
|
||||
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
|
||||
auto *ivUnroll =
|
||||
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
|
||||
->getResult(0);
|
||||
|
|
|
@ -35,50 +35,50 @@ using namespace mlir;
|
|||
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
|
||||
/// the specified trip count, stride, and unroll factor. Returns nullptr when
|
||||
/// the trip count can't be expressed as an affine expression.
|
||||
AffineMap *mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder) {
|
||||
auto *lbMap = forStmt.getLowerBoundMap();
|
||||
AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder) {
|
||||
auto lbMap = forStmt.getLowerBoundMap();
|
||||
|
||||
// Single result lower bound map only.
|
||||
if (lbMap->getNumResults() != 1)
|
||||
return nullptr;
|
||||
if (lbMap.getNumResults() != 1)
|
||||
return AffineMap::Invalid();
|
||||
|
||||
// Sometimes, the trip count cannot be expressed as an affine expression.
|
||||
auto tripCount = getTripCountExpr(forStmt);
|
||||
if (!tripCount)
|
||||
return nullptr;
|
||||
return AffineMap::Invalid();
|
||||
|
||||
AffineExpr lb(lbMap->getResult(0));
|
||||
AffineExpr lb(lbMap.getResult(0));
|
||||
unsigned step = forStmt.getStep();
|
||||
auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
|
||||
|
||||
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
|
||||
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
|
||||
{newUb}, {});
|
||||
}
|
||||
|
||||
/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
|
||||
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
|
||||
/// Returns nullptr when the trip count can't be expressed as an affine
|
||||
/// expression.
|
||||
AffineMap *mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder) {
|
||||
auto *lbMap = forStmt.getLowerBoundMap();
|
||||
/// Returns an AffinMap with nullptr storage (that evaluates to false)
|
||||
/// when the trip count can't be expressed as an affine expression.
|
||||
AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
|
||||
unsigned unrollFactor,
|
||||
MLFuncBuilder *builder) {
|
||||
auto lbMap = forStmt.getLowerBoundMap();
|
||||
|
||||
// Single result lower bound map only.
|
||||
if (lbMap->getNumResults() != 1)
|
||||
return nullptr;
|
||||
if (lbMap.getNumResults() != 1)
|
||||
return AffineMap::Invalid();
|
||||
|
||||
// Sometimes the trip count cannot be expressed as an affine expression.
|
||||
AffineExpr tripCount(getTripCountExpr(forStmt));
|
||||
if (!tripCount)
|
||||
return nullptr;
|
||||
return AffineMap::Invalid();
|
||||
|
||||
AffineExpr lb(lbMap->getResult(0));
|
||||
AffineExpr lb(lbMap.getResult(0));
|
||||
unsigned step = forStmt.getStep();
|
||||
auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
|
||||
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
|
||||
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
|
||||
{newLb}, {});
|
||||
}
|
||||
|
||||
|
@ -91,7 +91,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
|
|||
return false;
|
||||
|
||||
// TODO(mlir-team): there is no builder for a max.
|
||||
if (forStmt->getLowerBoundMap()->getNumResults() != 1)
|
||||
if (forStmt->getLowerBoundMap().getNumResults() != 1)
|
||||
return false;
|
||||
|
||||
// Replaces all IV uses to its single iteration value.
|
||||
|
@ -140,7 +140,7 @@ void mlir::promoteSingleIterationLoops(MLFunction *f) {
|
|||
/// the pair specifies the delay applied to that group of statements. Returns
|
||||
/// nullptr if the generated loop simplifies to a single iteration one.
|
||||
static ForStmt *
|
||||
generateLoop(AffineMap *lb, AffineMap *ub,
|
||||
generateLoop(AffineMap lb, AffineMap ub,
|
||||
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
|
||||
&stmtGroupQueue,
|
||||
unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
|
||||
|
@ -296,7 +296,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
|
|||
// of statements is paired with its delay.
|
||||
std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
|
||||
|
||||
auto *origLbMap = forStmt->getLowerBoundMap();
|
||||
auto origLbMap = forStmt->getLowerBoundMap();
|
||||
uint64_t lbDelay = 0;
|
||||
MLFuncBuilder b(forStmt);
|
||||
for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
|
||||
|
|
|
@ -100,7 +100,8 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
|
|||
->getResult());
|
||||
|
||||
auto d0 = bInner.getAffineDimExpr(0);
|
||||
auto *modTwoMap = bInner.getAffineMap(1, 0, {d0 % 2}, {});
|
||||
auto modTwoMap =
|
||||
bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {});
|
||||
auto ivModTwoOp =
|
||||
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
|
||||
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0)))
|
||||
|
|
|
@ -59,10 +59,10 @@ PassResult SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {
|
|||
void visitOperationStmt(OperationStmt *opStmt) {
|
||||
for (auto attr : opStmt->getAttrs()) {
|
||||
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
|
||||
MutableAffineMap mMap(mapAttr->getValue(), context);
|
||||
MutableAffineMap mMap(mapAttr->getValue());
|
||||
mMap.simplify();
|
||||
auto *map = mMap.getAffineMap();
|
||||
opStmt->setAttr(attr.first, AffineMapAttr::get(map, context));
|
||||
auto map = mMap.getAffineMap();
|
||||
opStmt->setAttr(attr.first, AffineMapAttr::get(map));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,14 +48,14 @@ static bool isMemRefDereferencingOp(const Operation &op) {
|
|||
// extended to add additional indices at any position.
|
||||
bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
|
||||
ArrayRef<SSAValue *> extraIndices,
|
||||
AffineMap *indexRemap) {
|
||||
AffineMap indexRemap) {
|
||||
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
|
||||
(void)newMemRefRank; // unused in opt mode
|
||||
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
|
||||
(void)newMemRefRank;
|
||||
if (indexRemap) {
|
||||
assert(indexRemap->getNumInputs() == oldMemRefRank);
|
||||
assert(indexRemap->getNumResults() + extraIndices.size() == newMemRefRank);
|
||||
assert(indexRemap.getNumInputs() == oldMemRefRank);
|
||||
assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
|
||||
} else {
|
||||
assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue