Parse affine map range sizes.

PiperOrigin-RevId: 204240947
This commit is contained in:
Uday Bondhugula 2018-07-11 21:31:07 -07:00 committed by jpienaar
parent b488a035aa
commit 8fbaf79afb
10 changed files with 126 additions and 20 deletions

View File

@ -41,7 +41,14 @@ class AffineExpr;
class AffineMap {
public:
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results, MLIRContext *context);
ArrayRef<AffineExpr *> results,
ArrayRef<AffineExpr *> rangeSizes,
MLIRContext *context);
/// Returns true if the co-domain (or more loosely speaking, range) of this
/// map is bounded. Bounded affine maps have a size (extent) for each of
/// their range dimensions (more accurately co-domain dimensions).
bool isBounded() const { return rangeSizes != nullptr; }
// Prints affine map to 'os'.
void print(raw_ostream &os) const;
@ -55,12 +62,17 @@ public:
return ArrayRef<AffineExpr *>(results, numResults);
}
private:
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
AffineExpr *const *results);
ArrayRef<AffineExpr *> getRangeSizes() const {
return rangeSizes ? ArrayRef<AffineExpr *>(rangeSizes, numResults)
: ArrayRef<AffineExpr *>();
}
AffineMap(const AffineMap&) = delete;
void operator=(const AffineMap&) = delete;
private:
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
AffineExpr *const *results, AffineExpr *const *rangeSizes);
AffineMap(const AffineMap &) = delete;
void operator=(const AffineMap &) = delete;
const unsigned numDims;
const unsigned numSymbols;
@ -69,6 +81,10 @@ public:
/// The affine expressions for this (multi-dimensional) map.
/// TODO: use trailing objects for this.
AffineExpr *const *const results;
/// The extents along each of the range dimensions if the map is bounded,
/// nullptr otherwise.
AffineExpr *const *const rangeSizes;
};
} // end namespace mlir

View File

@ -75,7 +75,8 @@ public:
// Affine Expressions and Affine Map.
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results);
ArrayRef<AffineExpr *> results,
ArrayRef<AffineExpr *> rangeSizes);
AffineDimExpr *getDimExpr(unsigned position);
AffineSymbolExpr *getSymbolExpr(unsigned position);
AffineConstantExpr *getConstantExpr(int64_t constant);

View File

@ -23,9 +23,9 @@
using namespace mlir;
AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
AffineExpr *const *results)
AffineExpr *const *results, AffineExpr *const *rangeSizes)
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
results(results) {}
results(results), rangeSizes(rangeSizes) {}
/// Fold to a constant when possible. Canonicalize so that only the RHS is a
/// constant. (4 + d0 becomes d0 + 4). If only one of them is a symbolic

View File

@ -392,6 +392,17 @@ void AffineMap::print(raw_ostream &os) const {
os << " -> (";
interleave(getResults(), [&](AffineExpr *expr) { os << *expr; },
[&]() { os << ", "; });
os << ")";
if (!isBounded()) {
os << "\n";
return;
}
// Print range sizes for bounded affine maps.
os << " size (";
interleave(getRangeSizes(), [&](AffineExpr *expr) { os << *expr; },
[&]() { os << ", "; });
os << ")\n";
}

View File

@ -99,8 +99,9 @@ ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
//===----------------------------------------------------------------------===//
AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results) {
return AffineMap::get(dimCount, symbolCount, results, context);
ArrayRef<AffineExpr *> results,
ArrayRef<AffineExpr *> rangeSizes) {
return AffineMap::get(dimCount, symbolCount, results, rangeSizes, context);
}
AffineDimExpr *Builder::getDimExpr(unsigned position) {

View File

@ -55,21 +55,23 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType*> {
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
// Affine maps are uniqued based on their dim/symbol counts and affine
// expressions.
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>>;
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>,
ArrayRef<AffineExpr *>>;
using DenseMapInfo<AffineMap *>::getHashValue;
using DenseMapInfo<AffineMap *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
std::get<0>(key), std::get<1>(key),
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
}
static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
rhs->getResults());
rhs->getResults(), rhs->getRangeSizes());
}
};
@ -555,14 +557,17 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results,
ArrayRef<AffineExpr *> rangeSizes,
MLIRContext *context) {
// The number of results can't be zero.
assert(!results.empty());
assert(rangeSizes.empty() || results.size() == rangeSizes.size());
auto &impl = context->getImpl();
// Check if we already have this affine map.
auto key = std::make_tuple(dimCount, symbolCount, results);
auto key = std::make_tuple(dimCount, symbolCount, results, rangeSizes);
auto existing = impl.affineMaps.insert_as(nullptr, key);
// If we already have it, return that value.
@ -575,8 +580,12 @@ AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
// Copy the results into the bump pointer.
results = impl.copyInto(ArrayRef<AffineExpr *>(results));
// Copy the results into the bump pointer.
rangeSizes = impl.copyInto(ArrayRef<AffineExpr *>(rangeSizes));
// Initialize the memory using placement new.
new (res) AffineMap(dimCount, symbolCount, results.size(), results.data());
new (res) AffineMap(dimCount, symbolCount, results.size(), results.data(),
rangeSizes.empty() ? nullptr : rangeSizes.data());
// Cache and return it.
return *existing.first = res;

View File

@ -628,6 +628,11 @@ private:
unsigned getNumDims() const { return dims.size(); }
unsigned getNumSymbols() const { return symbols.size(); }
/// Returns true if the only identifiers the parser accepts in affine
/// expressions are symbolic identifiers.
bool isPureSymbolic() const { return pureSymbolic; }
void setSymbolicParsing(bool val) { pureSymbolic = val; }
// Binary affine op parsing.
AffineLowPrecOp consumeIfLowPrecOp();
AffineHighPrecOp consumeIfHighPrecOp();
@ -657,6 +662,9 @@ private:
// TODO(bondhugula): could just use an vector/ArrayRef and scan the numbers.
llvm::StringMap<unsigned> dims;
llvm::StringMap<unsigned> symbols;
/// True if the parser should allow only symbolic identifiers in affine
/// expressions.
bool pureSymbolic = false;
};
} // end anonymous namespace
@ -664,6 +672,7 @@ private:
AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
AffineExpr *lhs,
AffineExpr *rhs) {
// TODO: make the error location info accurate.
switch (op) {
case Mul:
if (!lhs->isSymbolic() && !rhs->isSymbolic()) {
@ -828,15 +837,21 @@ AffineExpr *AffineMapParser::parseBareIdExpr() {
return (emitError("expected bare identifier"), nullptr);
StringRef sRef = getTokenSpelling();
// dims, symbols are all pairwise distinct.
if (dims.count(sRef)) {
if (isPureSymbolic())
return (emitError("identifier used is not a symbolic identifier"),
nullptr);
consumeToken(Token::bare_identifier);
return builder.getDimExpr(dims.lookup(sRef));
}
if (symbols.count(sRef)) {
consumeToken(Token::bare_identifier);
return builder.getSymbolExpr(symbols.lookup(sRef));
}
return (emitError("identifier is neither dimensional nor symbolic"), nullptr);
return (emitError("use of undeclared identifier"), nullptr);
}
/// Parse a positive integral constant appearing in an affine expression.
@ -1053,8 +1068,36 @@ AffineMap *AffineMapParser::parseAffineMapInline() {
if (parseCommaSeparatedList(Token::r_paren, parseElt, false))
return nullptr;
// Parse optional range sizes.
// (`size` `(` dim-size (`,` dim-size)* `)`)?
// TODO: check if sizes are non-negative whenever they are constant.
SmallVector<AffineExpr *, 4> rangeSizes;
if (consumeIf(Token::kw_size)) {
// Location of the l_paren token (if it exists) for error reporting later.
auto loc = getToken().getLoc();
if (!consumeIf(Token::l_paren))
return (emitError("expected '(' at start of affine map range"), nullptr);
auto parseRangeSize = [&]() -> ParseResult {
auto *elt = parseAffineExpr();
ParseResult res = elt ? ParseSuccess : ParseFailure;
rangeSizes.push_back(elt);
return res;
};
setSymbolicParsing(true);
if (parseCommaSeparatedList(Token::r_paren, parseRangeSize, false))
return nullptr;
if (exprs.size() > rangeSizes.size())
return (emitError(loc, "fewer range sizes than range expressions"),
nullptr);
if (exprs.size() < rangeSizes.size())
return (emitError(loc, "more range sizes than range expressions"),
nullptr);
}
// Parsed a valid affine map.
return builder.getAffineMap(dims.size(), symbols.size(), exprs);
return builder.getAffineMap(dims.size(), symbols.size(), exprs, rangeSizes);
}
AffineMap *Parser::parseAffineMapInline() {

View File

@ -103,6 +103,7 @@ TOK_KEYWORD(memref)
TOK_KEYWORD(mlfunc)
TOK_KEYWORD(mod)
TOK_KEYWORD(return)
TOK_KEYWORD(size)
TOK_KEYWORD(tensor)
TOK_KEYWORD(true)
TOK_KEYWORD(vector)

View File

@ -33,7 +33,7 @@
#hello_world = (i, j) [s0] -> i + s0, j) ; expected-error {{expected '(' at start of affine map range}}
; -----
#hello_world = (i, j) [s0] -> (x) ; expected-error {{identifier is neither dimensional nor symbolic}}
#hello_world = (i, j) [s0] -> (x) ; expected-error {{use of undeclared identifier}}
; -----
#hello_world = (i, j, i) [s0] -> (i) ; expected-error {{dimensional identifier name reused}}
@ -98,7 +98,22 @@
#hello_world = (i, j) [s0, s1] -> (-1*i j, j) ; expected-error {{expected ',' or ')'}}
; -----
#hello_world = (i, j) -> (i, 3*d0 + ) ; expected-error {{identifier is neither dimensional nor symbolic}}
#hello_world = (i, j) -> (i, 3*d0 + ) ; expected-error {{use of undeclared identifier}}
; -----
#hello_world = (i, j) -> (i, j) size (10, x) ; expected-error {{use of undeclared identifier}}
; -----
#hello_world = (i, j) [M] -> (i, j) size (10, j) ; expected-error {{identifier used is not a symbolic identifier}}
; -----
#hello_world = (i, j) [M] -> (i, j) size (10, M+i) ; expected-error {{identifier used is not a symbolic identifier}}
; -----
#hello_world = (i, j) -> (i, j) size (10) ; expected-error {{fewer range sizes than range expressions}}
; -----
#hello_world = (i, j) -> (i, j) size (10, 20, 30) ; expected-error {{more range sizes than range expressions}}
; TODO(bondhugula): Add more tests; coverage of error messages emitted not complete

View File

@ -110,3 +110,12 @@
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
#hello_world39 = (i, j) [M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M)
; CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
#hello_world40 = (i, j) -> (i, j) size (10, 20)
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
#hello_world41 = (i, j) [N, M] -> (i, j) size (N, M+10)
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
#hello_world42 = (i, j) [N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)