forked from OSchip/llvm-project
Parse affine map range sizes.
PiperOrigin-RevId: 204240947
This commit is contained in:
parent
b488a035aa
commit
8fbaf79afb
|
@ -41,7 +41,14 @@ class AffineExpr;
|
||||||
class AffineMap {
|
class AffineMap {
|
||||||
public:
|
public:
|
||||||
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
|
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> results, MLIRContext *context);
|
ArrayRef<AffineExpr *> results,
|
||||||
|
ArrayRef<AffineExpr *> rangeSizes,
|
||||||
|
MLIRContext *context);
|
||||||
|
|
||||||
|
/// Returns true if the co-domain (or more loosely speaking, range) of this
|
||||||
|
/// map is bounded. Bounded affine maps have a size (extent) for each of
|
||||||
|
/// their range dimensions (more accurately co-domain dimensions).
|
||||||
|
bool isBounded() const { return rangeSizes != nullptr; }
|
||||||
|
|
||||||
// Prints affine map to 'os'.
|
// Prints affine map to 'os'.
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os) const;
|
||||||
|
@ -55,12 +62,17 @@ public:
|
||||||
return ArrayRef<AffineExpr *>(results, numResults);
|
return ArrayRef<AffineExpr *>(results, numResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
ArrayRef<AffineExpr *> getRangeSizes() const {
|
||||||
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
return rangeSizes ? ArrayRef<AffineExpr *>(rangeSizes, numResults)
|
||||||
AffineExpr *const *results);
|
: ArrayRef<AffineExpr *>();
|
||||||
|
}
|
||||||
|
|
||||||
AffineMap(const AffineMap&) = delete;
|
private:
|
||||||
void operator=(const AffineMap&) = delete;
|
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
||||||
|
AffineExpr *const *results, AffineExpr *const *rangeSizes);
|
||||||
|
|
||||||
|
AffineMap(const AffineMap &) = delete;
|
||||||
|
void operator=(const AffineMap &) = delete;
|
||||||
|
|
||||||
const unsigned numDims;
|
const unsigned numDims;
|
||||||
const unsigned numSymbols;
|
const unsigned numSymbols;
|
||||||
|
@ -69,6 +81,10 @@ public:
|
||||||
/// The affine expressions for this (multi-dimensional) map.
|
/// The affine expressions for this (multi-dimensional) map.
|
||||||
/// TODO: use trailing objects for this.
|
/// TODO: use trailing objects for this.
|
||||||
AffineExpr *const *const results;
|
AffineExpr *const *const results;
|
||||||
|
|
||||||
|
/// The extents along each of the range dimensions if the map is bounded,
|
||||||
|
/// nullptr otherwise.
|
||||||
|
AffineExpr *const *const rangeSizes;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
|
@ -75,7 +75,8 @@ public:
|
||||||
|
|
||||||
// Affine Expressions and Affine Map.
|
// Affine Expressions and Affine Map.
|
||||||
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
|
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> results);
|
ArrayRef<AffineExpr *> results,
|
||||||
|
ArrayRef<AffineExpr *> rangeSizes);
|
||||||
AffineDimExpr *getDimExpr(unsigned position);
|
AffineDimExpr *getDimExpr(unsigned position);
|
||||||
AffineSymbolExpr *getSymbolExpr(unsigned position);
|
AffineSymbolExpr *getSymbolExpr(unsigned position);
|
||||||
AffineConstantExpr *getConstantExpr(int64_t constant);
|
AffineConstantExpr *getConstantExpr(int64_t constant);
|
||||||
|
|
|
@ -23,9 +23,9 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
|
||||||
AffineExpr *const *results)
|
AffineExpr *const *results, AffineExpr *const *rangeSizes)
|
||||||
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
|
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
|
||||||
results(results) {}
|
results(results), rangeSizes(rangeSizes) {}
|
||||||
|
|
||||||
/// Fold to a constant when possible. Canonicalize so that only the RHS is a
|
/// Fold to a constant when possible. Canonicalize so that only the RHS is a
|
||||||
/// constant. (4 + d0 becomes d0 + 4). If only one of them is a symbolic
|
/// constant. (4 + d0 becomes d0 + 4). If only one of them is a symbolic
|
||||||
|
|
|
@ -392,6 +392,17 @@ void AffineMap::print(raw_ostream &os) const {
|
||||||
os << " -> (";
|
os << " -> (";
|
||||||
interleave(getResults(), [&](AffineExpr *expr) { os << *expr; },
|
interleave(getResults(), [&](AffineExpr *expr) { os << *expr; },
|
||||||
[&]() { os << ", "; });
|
[&]() { os << ", "; });
|
||||||
|
os << ")";
|
||||||
|
|
||||||
|
if (!isBounded()) {
|
||||||
|
os << "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print range sizes for bounded affine maps.
|
||||||
|
os << " size (";
|
||||||
|
interleave(getRangeSizes(), [&](AffineExpr *expr) { os << *expr; },
|
||||||
|
[&]() { os << ", "; });
|
||||||
os << ")\n";
|
os << ")\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -99,8 +99,9 @@ ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
|
AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> results) {
|
ArrayRef<AffineExpr *> results,
|
||||||
return AffineMap::get(dimCount, symbolCount, results, context);
|
ArrayRef<AffineExpr *> rangeSizes) {
|
||||||
|
return AffineMap::get(dimCount, symbolCount, results, rangeSizes, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineDimExpr *Builder::getDimExpr(unsigned position) {
|
AffineDimExpr *Builder::getDimExpr(unsigned position) {
|
||||||
|
|
|
@ -55,21 +55,23 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType*> {
|
||||||
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
|
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
|
||||||
// Affine maps are uniqued based on their dim/symbol counts and affine
|
// Affine maps are uniqued based on their dim/symbol counts and affine
|
||||||
// expressions.
|
// expressions.
|
||||||
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>>;
|
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>,
|
||||||
|
ArrayRef<AffineExpr *>>;
|
||||||
using DenseMapInfo<AffineMap *>::getHashValue;
|
using DenseMapInfo<AffineMap *>::getHashValue;
|
||||||
using DenseMapInfo<AffineMap *>::isEqual;
|
using DenseMapInfo<AffineMap *>::isEqual;
|
||||||
|
|
||||||
static unsigned getHashValue(KeyTy key) {
|
static unsigned getHashValue(KeyTy key) {
|
||||||
return hash_combine(
|
return hash_combine(
|
||||||
std::get<0>(key), std::get<1>(key),
|
std::get<0>(key), std::get<1>(key),
|
||||||
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
|
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
|
||||||
|
hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
|
static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
|
return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
|
||||||
rhs->getResults());
|
rhs->getResults(), rhs->getRangeSizes());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -555,14 +557,17 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
||||||
|
|
||||||
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
||||||
ArrayRef<AffineExpr *> results,
|
ArrayRef<AffineExpr *> results,
|
||||||
|
ArrayRef<AffineExpr *> rangeSizes,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
// The number of results can't be zero.
|
// The number of results can't be zero.
|
||||||
assert(!results.empty());
|
assert(!results.empty());
|
||||||
|
|
||||||
|
assert(rangeSizes.empty() || results.size() == rangeSizes.size());
|
||||||
|
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Check if we already have this affine map.
|
// Check if we already have this affine map.
|
||||||
auto key = std::make_tuple(dimCount, symbolCount, results);
|
auto key = std::make_tuple(dimCount, symbolCount, results, rangeSizes);
|
||||||
auto existing = impl.affineMaps.insert_as(nullptr, key);
|
auto existing = impl.affineMaps.insert_as(nullptr, key);
|
||||||
|
|
||||||
// If we already have it, return that value.
|
// If we already have it, return that value.
|
||||||
|
@ -575,8 +580,12 @@ AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
|
||||||
// Copy the results into the bump pointer.
|
// Copy the results into the bump pointer.
|
||||||
results = impl.copyInto(ArrayRef<AffineExpr *>(results));
|
results = impl.copyInto(ArrayRef<AffineExpr *>(results));
|
||||||
|
|
||||||
|
// Copy the results into the bump pointer.
|
||||||
|
rangeSizes = impl.copyInto(ArrayRef<AffineExpr *>(rangeSizes));
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (res) AffineMap(dimCount, symbolCount, results.size(), results.data());
|
new (res) AffineMap(dimCount, symbolCount, results.size(), results.data(),
|
||||||
|
rangeSizes.empty() ? nullptr : rangeSizes.data());
|
||||||
|
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return *existing.first = res;
|
return *existing.first = res;
|
||||||
|
|
|
@ -628,6 +628,11 @@ private:
|
||||||
unsigned getNumDims() const { return dims.size(); }
|
unsigned getNumDims() const { return dims.size(); }
|
||||||
unsigned getNumSymbols() const { return symbols.size(); }
|
unsigned getNumSymbols() const { return symbols.size(); }
|
||||||
|
|
||||||
|
/// Returns true if the only identifiers the parser accepts in affine
|
||||||
|
/// expressions are symbolic identifiers.
|
||||||
|
bool isPureSymbolic() const { return pureSymbolic; }
|
||||||
|
void setSymbolicParsing(bool val) { pureSymbolic = val; }
|
||||||
|
|
||||||
// Binary affine op parsing.
|
// Binary affine op parsing.
|
||||||
AffineLowPrecOp consumeIfLowPrecOp();
|
AffineLowPrecOp consumeIfLowPrecOp();
|
||||||
AffineHighPrecOp consumeIfHighPrecOp();
|
AffineHighPrecOp consumeIfHighPrecOp();
|
||||||
|
@ -657,6 +662,9 @@ private:
|
||||||
// TODO(bondhugula): could just use an vector/ArrayRef and scan the numbers.
|
// TODO(bondhugula): could just use an vector/ArrayRef and scan the numbers.
|
||||||
llvm::StringMap<unsigned> dims;
|
llvm::StringMap<unsigned> dims;
|
||||||
llvm::StringMap<unsigned> symbols;
|
llvm::StringMap<unsigned> symbols;
|
||||||
|
/// True if the parser should allow only symbolic identifiers in affine
|
||||||
|
/// expressions.
|
||||||
|
bool pureSymbolic = false;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
@ -664,6 +672,7 @@ private:
|
||||||
AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
|
AffineExpr *AffineMapParser::getBinaryAffineOpExpr(AffineHighPrecOp op,
|
||||||
AffineExpr *lhs,
|
AffineExpr *lhs,
|
||||||
AffineExpr *rhs) {
|
AffineExpr *rhs) {
|
||||||
|
// TODO: make the error location info accurate.
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case Mul:
|
case Mul:
|
||||||
if (!lhs->isSymbolic() && !rhs->isSymbolic()) {
|
if (!lhs->isSymbolic() && !rhs->isSymbolic()) {
|
||||||
|
@ -828,15 +837,21 @@ AffineExpr *AffineMapParser::parseBareIdExpr() {
|
||||||
return (emitError("expected bare identifier"), nullptr);
|
return (emitError("expected bare identifier"), nullptr);
|
||||||
|
|
||||||
StringRef sRef = getTokenSpelling();
|
StringRef sRef = getTokenSpelling();
|
||||||
|
// dims, symbols are all pairwise distinct.
|
||||||
if (dims.count(sRef)) {
|
if (dims.count(sRef)) {
|
||||||
|
if (isPureSymbolic())
|
||||||
|
return (emitError("identifier used is not a symbolic identifier"),
|
||||||
|
nullptr);
|
||||||
consumeToken(Token::bare_identifier);
|
consumeToken(Token::bare_identifier);
|
||||||
return builder.getDimExpr(dims.lookup(sRef));
|
return builder.getDimExpr(dims.lookup(sRef));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (symbols.count(sRef)) {
|
if (symbols.count(sRef)) {
|
||||||
consumeToken(Token::bare_identifier);
|
consumeToken(Token::bare_identifier);
|
||||||
return builder.getSymbolExpr(symbols.lookup(sRef));
|
return builder.getSymbolExpr(symbols.lookup(sRef));
|
||||||
}
|
}
|
||||||
return (emitError("identifier is neither dimensional nor symbolic"), nullptr);
|
|
||||||
|
return (emitError("use of undeclared identifier"), nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse a positive integral constant appearing in an affine expression.
|
/// Parse a positive integral constant appearing in an affine expression.
|
||||||
|
@ -1053,8 +1068,36 @@ AffineMap *AffineMapParser::parseAffineMapInline() {
|
||||||
if (parseCommaSeparatedList(Token::r_paren, parseElt, false))
|
if (parseCommaSeparatedList(Token::r_paren, parseElt, false))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
|
// Parse optional range sizes.
|
||||||
|
// (`size` `(` dim-size (`,` dim-size)* `)`)?
|
||||||
|
// TODO: check if sizes are non-negative whenever they are constant.
|
||||||
|
SmallVector<AffineExpr *, 4> rangeSizes;
|
||||||
|
if (consumeIf(Token::kw_size)) {
|
||||||
|
// Location of the l_paren token (if it exists) for error reporting later.
|
||||||
|
auto loc = getToken().getLoc();
|
||||||
|
if (!consumeIf(Token::l_paren))
|
||||||
|
return (emitError("expected '(' at start of affine map range"), nullptr);
|
||||||
|
|
||||||
|
auto parseRangeSize = [&]() -> ParseResult {
|
||||||
|
auto *elt = parseAffineExpr();
|
||||||
|
ParseResult res = elt ? ParseSuccess : ParseFailure;
|
||||||
|
rangeSizes.push_back(elt);
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
setSymbolicParsing(true);
|
||||||
|
if (parseCommaSeparatedList(Token::r_paren, parseRangeSize, false))
|
||||||
|
return nullptr;
|
||||||
|
if (exprs.size() > rangeSizes.size())
|
||||||
|
return (emitError(loc, "fewer range sizes than range expressions"),
|
||||||
|
nullptr);
|
||||||
|
if (exprs.size() < rangeSizes.size())
|
||||||
|
return (emitError(loc, "more range sizes than range expressions"),
|
||||||
|
nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
// Parsed a valid affine map.
|
// Parsed a valid affine map.
|
||||||
return builder.getAffineMap(dims.size(), symbols.size(), exprs);
|
return builder.getAffineMap(dims.size(), symbols.size(), exprs, rangeSizes);
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineMap *Parser::parseAffineMapInline() {
|
AffineMap *Parser::parseAffineMapInline() {
|
||||||
|
|
|
@ -103,6 +103,7 @@ TOK_KEYWORD(memref)
|
||||||
TOK_KEYWORD(mlfunc)
|
TOK_KEYWORD(mlfunc)
|
||||||
TOK_KEYWORD(mod)
|
TOK_KEYWORD(mod)
|
||||||
TOK_KEYWORD(return)
|
TOK_KEYWORD(return)
|
||||||
|
TOK_KEYWORD(size)
|
||||||
TOK_KEYWORD(tensor)
|
TOK_KEYWORD(tensor)
|
||||||
TOK_KEYWORD(true)
|
TOK_KEYWORD(true)
|
||||||
TOK_KEYWORD(vector)
|
TOK_KEYWORD(vector)
|
||||||
|
|
|
@ -33,7 +33,7 @@
|
||||||
#hello_world = (i, j) [s0] -> i + s0, j) ; expected-error {{expected '(' at start of affine map range}}
|
#hello_world = (i, j) [s0] -> i + s0, j) ; expected-error {{expected '(' at start of affine map range}}
|
||||||
|
|
||||||
; -----
|
; -----
|
||||||
#hello_world = (i, j) [s0] -> (x) ; expected-error {{identifier is neither dimensional nor symbolic}}
|
#hello_world = (i, j) [s0] -> (x) ; expected-error {{use of undeclared identifier}}
|
||||||
|
|
||||||
; -----
|
; -----
|
||||||
#hello_world = (i, j, i) [s0] -> (i) ; expected-error {{dimensional identifier name reused}}
|
#hello_world = (i, j, i) [s0] -> (i) ; expected-error {{dimensional identifier name reused}}
|
||||||
|
@ -98,7 +98,22 @@
|
||||||
#hello_world = (i, j) [s0, s1] -> (-1*i j, j) ; expected-error {{expected ',' or ')'}}
|
#hello_world = (i, j) [s0, s1] -> (-1*i j, j) ; expected-error {{expected ',' or ')'}}
|
||||||
|
|
||||||
; -----
|
; -----
|
||||||
#hello_world = (i, j) -> (i, 3*d0 + ) ; expected-error {{identifier is neither dimensional nor symbolic}}
|
#hello_world = (i, j) -> (i, 3*d0 + ) ; expected-error {{use of undeclared identifier}}
|
||||||
|
|
||||||
|
; -----
|
||||||
|
#hello_world = (i, j) -> (i, j) size (10, x) ; expected-error {{use of undeclared identifier}}
|
||||||
|
|
||||||
|
; -----
|
||||||
|
#hello_world = (i, j) [M] -> (i, j) size (10, j) ; expected-error {{identifier used is not a symbolic identifier}}
|
||||||
|
|
||||||
|
; -----
|
||||||
|
#hello_world = (i, j) [M] -> (i, j) size (10, M+i) ; expected-error {{identifier used is not a symbolic identifier}}
|
||||||
|
|
||||||
|
; -----
|
||||||
|
#hello_world = (i, j) -> (i, j) size (10) ; expected-error {{fewer range sizes than range expressions}}
|
||||||
|
|
||||||
|
; -----
|
||||||
|
#hello_world = (i, j) -> (i, j) size (10, 20, 30) ; expected-error {{more range sizes than range expressions}}
|
||||||
|
|
||||||
; TODO(bondhugula): Add more tests; coverage of error messages emitted not complete
|
; TODO(bondhugula): Add more tests; coverage of error messages emitted not complete
|
||||||
|
|
||||||
|
|
|
@ -110,3 +110,12 @@
|
||||||
|
|
||||||
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
|
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
|
||||||
#hello_world39 = (i, j) [M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M)
|
#hello_world39 = (i, j) [M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M)
|
||||||
|
|
||||||
|
; CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
|
||||||
|
#hello_world40 = (i, j) -> (i, j) size (10, 20)
|
||||||
|
|
||||||
|
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
|
||||||
|
#hello_world41 = (i, j) [N, M] -> (i, j) size (N, M+10)
|
||||||
|
|
||||||
|
; CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
|
||||||
|
#hello_world42 = (i, j) [N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)
|
||||||
|
|
Loading…
Reference in New Issue