forked from OSchip/llvm-project
Parse affine map range sizes.
PiperOrigin-RevId: 204240947
This commit is contained in:
parent
b488a035aa
commit
8fbaf79afb
|
@ -41,7 +41,14 @@ class AffineExpr;
|
|||
class AffineMap {
|
||||
public:
|
||||
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr *> results, MLIRContext *context);
|
||||
ArrayRef<AffineExpr *> results,
|
||||
ArrayRef<AffineExpr *> rangeSizes,
|
||||
MLIRContext *context);
|
||||
|
||||
/// Returns true if the co-domain (or more loosely speaking, range) of this
|
||||
/// map is bounded. Bounded affine maps have a size (extent) for each of
|
||||
/// their range dimensions (more accurately co-domain dimensions).
|
||||
bool isBounded() const { return rangeSizes != nullptr; }
|
||||
|
||||
// Prints affine map to 'os'.
|
||||
void print(raw_ostream &os) const;
|
||||
|
@ -55,12 +62,17 @@ public:
|
|||
return ArrayRef<AffineExpr *>(results, numResults);
|
||||
}
|
||||
|
||||
private:
|
||||
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
||||
AffineExpr *const *results);
|
||||
ArrayRef<AffineExpr *> getRangeSizes() const {
|
||||
return rangeSizes ? ArrayRef<AffineExpr *>(rangeSizes, numResults)
|
||||
: ArrayRef<AffineExpr *>();
|
||||
}
|
||||
|
||||
AffineMap(const AffineMap&) = delete;
|
||||
void operator=(const AffineMap&) = delete;
|
||||
private:
|
||||
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
||||
AffineExpr *const *results, AffineExpr *const *rangeSizes);
|
||||
|
||||
AffineMap(const AffineMap &) = delete;
|
||||
void operator=(const AffineMap &) = delete;
|
||||
|
||||
const unsigned numDims;
|
||||
const unsigned numSymbols;
|
||||
|
@ -69,6 +81,10 @@ public:
|
|||
/// The affine expressions for this (multi-dimensional) map.
|
||||
/// TODO: use trailing objects for this.
|
||||
AffineExpr *const *const results;
|
||||
|
||||
/// The extents along each of the range dimensions if the map is bounded,
|
||||
/// nullptr otherwise.
|
||||
AffineExpr *const *const rangeSizes;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -75,7 +75,8 @@ public:
|
|||
|
||||
// Affine Expressions and Affine Map.
|
||||
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr *> results);
|
||||
ArrayRef<AffineExpr *> results,
|
||||
ArrayRef<AffineExpr *> rangeSizes);
|
||||
AffineDimExpr *getDimExpr(unsigned position);
|
||||
AffineSymbolExpr *getSymbolExpr(unsigned position);
|
||||
AffineConstantExpr *getConstantExpr(int64_t constant);
|
||||
|
|
|
@ -23,9 +23,9 @@
|
|||
using namespace mlir;
|
||||
|
||||
AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
||||
AffineExpr *const *results)
|
||||
AffineExpr *const *results, AffineExpr *const *rangeSizes)
|
||||
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
|
||||
results(results) {}
|
||||
results(results), rangeSizes(rangeSizes) {}
|
||||
|
||||
/// Fold to a constant when possible. Canonicalize so that only the RHS is a
|
||||
/// constant. (4 + d0 becomes d0 + 4). If only one of them is a symbolic
|
||||
|
|
|
@ -392,6 +392,17 @@ void AffineMap::print(raw_ostream &os) const {
|
|||
os << " -> (";
|
||||
interleave(getResults(), [&](AffineExpr *expr) { os << *expr; },
|
||||
[&]() { os << ", "; });
|
||||
os << ")";
|
||||
|
||||
if (!isBounded()) {
|
||||
os << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// Print range sizes for bounded affine maps.
|
||||
os << " size (";
|
||||
interleave(getRangeSizes(), [&](AffineExpr *expr) { os << *expr; },
|
||||
[&]() { os << ", "; });
|
||||
os << ")\n";
|
||||
}
|
||||
|
||||
|
|
|
@ -99,8 +99,9 @@ ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr *> results) {
|
||||
return AffineMap::get(dimCount, symbolCount, results, context);
|
||||
ArrayRef<AffineExpr *> results,
|
||||
ArrayRef<AffineExpr *> rangeSizes) {
|
||||
return AffineMap::get(dimCount, symbolCount, results, rangeSizes, context);
|
||||
}
|
||||
|
||||
AffineDimExpr *Builder::getDimExpr(unsigned position) {
|
||||
|
|
|
@ -55,21 +55,23 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType*> {
|
|||
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
|
||||
// Affine maps are uniqued based on their dim/symbol counts and affine
|
||||
// expressions.
|
||||
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>>;
|
||||
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>,
|
||||
ArrayRef<AffineExpr *>>;
|
||||
using DenseMapInfo<AffineMap *>::getHashValue;
|
||||
using DenseMapInfo<AffineMap *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(
|
||||
std::get<0>(key), std::get<1>(key),
|
||||
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
|
||||
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
|
||||
hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
|
||||
rhs->getResults());
|
||||
rhs->getResults(), rhs->getRangeSizes());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -555,14 +557,17 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
|||
|
||||
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr *> results,
|
||||
ArrayRef<AffineExpr *> rangeSizes,
|
||||
MLIRContext *context) {
|
||||
// The number of results can't be zero.
|
||||
assert(!results.empty());
|
||||
|
||||
assert(rangeSizes.empty() || results.size() == rangeSizes.size());
|
||||
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Check if we already have this affine map.
|
||||
auto key = std::make_tuple(dimCount, symbolCount, results);
|
||||
auto key = std::make_tuple(dimCount, symbolCount, results, rangeSizes);
|
||||
auto existing = impl.affineMaps.insert_as(nullptr, key);
|
||||
|
||||
// If we already have it, return that value.
|
||||
|
@ -575,8 +580,12 @@ AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
|||
// Copy the results into the bump pointer.
|
||||
results = impl.copyInto(ArrayRef<AffineExpr *>(results));
|
||||
|
||||
// Copy the results into the bump pointer.
|
||||
rangeSizes = impl.copyInto(ArrayRef<AffineExpr *>(rangeSizes));
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (res) AffineMap(dimCount, symbolCount, results.size(), results.data());
|
||||
new (res) AffineMap(dimCount, symbolCount, results.size(), results.data(),
|
||||
rangeSizes.empty() ? nullptr : rangeSizes.data());
|
||||
|
||||
// Cache and return it.
|
||||
return *existing.first = res;
|
||||
|
|
|
@ -628,6 +628,11 @@ private:
|
|||
unsigned getNumDims() const { return dims.size(); }
|
||||
unsigned getNumSymbols() const { return symbols.size(); }
|
||||
|
||||
/// Returns true if the only identifiers the parser accepts in affine
|
||||
/// expressions are symbolic identifiers.
|
||||
bool isPureSymbolic() const { return pureSymbolic; }
|
||||
void setSymbolicParsing(bool val) { pureSymbolic = val; }
|
||||
|
||||
// Binary affine op parsing.
|
||||
AffineLowPrecOp consumeIfLowPrecOp();
|
||||
AffineHighPrecOp consumeIfHighPrecOp();
|
||||
|
@ -657,6 +662,9 @@ private:
|
|||
// TODO(bondhugula): could just use an vector/ArrayRef and scan the numbers.
|
||||
llvm::StringMap<unsigned> dims;
|
||||
llvm::StringMap<unsigned> symbols;
|
||||
/// True if the parser should allow only symbolic identifiers in affine
|
||||
/// expressions.
|
||||
bool pureSymbolic = false;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -664,6 +672,7 @@ private:
|
|||
AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
|
||||
AffineExpr *lhs,
|
||||
AffineExpr *rhs) {
|
||||
// TODO: make the error location info accurate.
|
||||
switch (op) {
|
||||
case Mul:
|
||||
if (!lhs->isSymbolic() && !rhs->isSymbolic()) {
|
||||
|
@ -828,15 +837,21 @@ AffineExpr *AffineMapParser::parseBareIdExpr() {
|
|||
return (emitError("expected bare identifier"), nullptr);
|
||||
|
||||
StringRef sRef = getTokenSpelling();
|
||||
// dims, symbols are all pairwise distinct.
|
||||
if (dims.count(sRef)) {
|
||||
if (isPureSymbolic())
|
||||
return (emitError("identifier used is not a symbolic identifier"),
|
||||
nullptr);
|
||||
consumeToken(Token::bare_identifier);
|
||||
return builder.getDimExpr(dims.lookup(sRef));
|
||||
}
|
||||
|
||||
if (symbols.count(sRef)) {
|
||||
consumeToken(Token::bare_identifier);
|
||||
return builder.getSymbolExpr(symbols.lookup(sRef));
|
||||
}
|
||||
return (emitError("identifier is neither dimensional nor symbolic"), nullptr);
|
||||
|
||||
return (emitError("use of undeclared identifier"), nullptr);
|
||||
}
|
||||
|
||||
/// Parse a positive integral constant appearing in an affine expression.
|
||||
|
@ -1053,8 +1068,36 @@ AffineMap *AffineMapParser::parseAffineMapInline() {
|
|||
if (parseCommaSeparatedList(Token::r_paren, parseElt, false))
|
||||
return nullptr;
|
||||
|
||||
// Parse optional range sizes.
|
||||
// (`size` `(` dim-size (`,` dim-size)* `)`)?
|
||||
// TODO: check if sizes are non-negative whenever they are constant.
|
||||
SmallVector<AffineExpr *, 4> rangeSizes;
|
||||
if (consumeIf(Token::kw_size)) {
|
||||
// Location of the l_paren token (if it exists) for error reporting later.
|
||||
auto loc = getToken().getLoc();
|
||||
if (!consumeIf(Token::l_paren))
|
||||
return (emitError("expected '(' at start of affine map range"), nullptr);
|
||||
|
||||
auto parseRangeSize = [&]() -> ParseResult {
|
||||
auto *elt = parseAffineExpr();
|
||||
ParseResult res = elt ? ParseSuccess : ParseFailure;
|
||||
rangeSizes.push_back(elt);
|
||||
return res;
|
||||
};
|
||||
|
||||
setSymbolicParsing(true);
|
||||
if (parseCommaSeparatedList(Token::r_paren, parseRangeSize, false))
|
||||
return nullptr;
|
||||
if (exprs.size() > rangeSizes.size())
|
||||
return (emitError(loc, "fewer range sizes than range expressions"),
|
||||
nullptr);
|
||||
if (exprs.size() < rangeSizes.size())
|
||||
return (emitError(loc, "more range sizes than range expressions"),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
// Parsed a valid affine map.
|
||||
return builder.getAffineMap(dims.size(), symbols.size(), exprs);
|
||||
return builder.getAffineMap(dims.size(), symbols.size(), exprs, rangeSizes);
|
||||
}
|
||||
|
||||
AffineMap *Parser::parseAffineMapInline() {
|
||||
|
|
|
@ -103,6 +103,7 @@ TOK_KEYWORD(memref)
|
|||
TOK_KEYWORD(mlfunc)
|
||||
TOK_KEYWORD(mod)
|
||||
TOK_KEYWORD(return)
|
||||
TOK_KEYWORD(size)
|
||||
TOK_KEYWORD(tensor)
|
||||
TOK_KEYWORD(true)
|
||||
TOK_KEYWORD(vector)
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
#hello_world = (i, j) [s0] -> i + s0, j) ; expected-error {{expected '(' at start of affine map range}}
|
||||
|
||||
; -----
|
||||
#hello_world = (i, j) [s0] -> (x) ; expected-error {{identifier is neither dimensional nor symbolic}}
|
||||
#hello_world = (i, j) [s0] -> (x) ; expected-error {{use of undeclared identifier}}
|
||||
|
||||
; -----
|
||||
#hello_world = (i, j, i) [s0] -> (i) ; expected-error {{dimensional identifier name reused}}
|
||||
|
@ -98,7 +98,22 @@
|
|||
#hello_world = (i, j) [s0, s1] -> (-1*i j, j) ; expected-error {{expected ',' or ')'}}
|
||||
|
||||
; -----
|
||||
#hello_world = (i, j) -> (i, 3*d0 + ) ; expected-error {{identifier is neither dimensional nor symbolic}}
|
||||
#hello_world = (i, j) -> (i, 3*d0 + ) ; expected-error {{use of undeclared identifier}}
|
||||
|
||||
; -----
|
||||
#hello_world = (i, j) -> (i, j) size (10, x) ; expected-error {{use of undeclared identifier}}
|
||||
|
||||
; -----
|
||||
#hello_world = (i, j) [M] -> (i, j) size (10, j) ; expected-error {{identifier used is not a symbolic identifier}}
|
||||
|
||||
; -----
|
||||
#hello_world = (i, j) [M] -> (i, j) size (10, M+i) ; expected-error {{identifier used is not a symbolic identifier}}
|
||||
|
||||
; -----
|
||||
#hello_world = (i, j) -> (i, j) size (10) ; expected-error {{fewer range sizes than range expressions}}
|
||||
|
||||
; -----
|
||||
#hello_world = (i, j) -> (i, j) size (10, 20, 30) ; expected-error {{more range sizes than range expressions}}
|
||||
|
||||
; TODO(bondhugula): Add more tests; coverage of error messages emitted not complete
|
||||
|
||||
|
|
|
@ -110,3 +110,12 @@
|
|||
|
||||
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
|
||||
#hello_world39 = (i, j) [M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M)
|
||||
|
||||
; CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
|
||||
#hello_world40 = (i, j) -> (i, j) size (10, 20)
|
||||
|
||||
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
|
||||
#hello_world41 = (i, j) [N, M] -> (i, j) size (N, M+10)
|
||||
|
||||
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
|
||||
#hello_world42 = (i, j) [N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)
|
||||
|
|
Loading…
Reference in New Issue