forked from OSchip/llvm-project
[MLIR] IntegerSet value type
This CL applies the same pattern as AffineMap to IntegerSet: a simple struct that acts as the storage is allocated in the bump pointer. The IntegerSet is immutable and accessed everywhere by value. Note that unlike AffineMap, it is not possible to remove the MLIRContext parameter when constructing an IntegerSet for now. One possible way to achieve this would be to add an enum to distinguish between the mathematically empty set, the universe set and other sets. This is left for future discussion. PiperOrigin-RevId: 216545361
This commit is contained in:
parent
5e3cca906a
commit
b04f881dcb
|
@ -78,7 +78,7 @@ private:
|
|||
/// A mutable integer set. Its affine expressions are however unique.
|
||||
struct MutableIntegerSet {
|
||||
public:
|
||||
MutableIntegerSet(IntegerSet *set, MLIRContext *context);
|
||||
MutableIntegerSet(IntegerSet set, MLIRContext *context);
|
||||
|
||||
/// Create a universal set (no constraints).
|
||||
MutableIntegerSet(unsigned numDims, unsigned numSymbols,
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
|
||||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "llvm/ADT/ilist.h"
|
||||
#include "llvm/ADT/ilist_node.h"
|
||||
|
||||
|
@ -95,7 +96,7 @@ public:
|
|||
HyperRectangularSet(unsigned numDims, unsigned numSymbols,
|
||||
ArrayRef<ArrayRef<AffineExpr>> lbs,
|
||||
ArrayRef<ArrayRef<AffineExpr>> ubs, MLIRContext *context,
|
||||
IntegerSet *symbolContext = nullptr);
|
||||
IntegerSet symbolContext = IntegerSet());
|
||||
|
||||
unsigned getNumDims() const { return numDims; }
|
||||
unsigned getNumSymbols() const { return numSymbols; }
|
||||
|
|
|
@ -84,6 +84,7 @@ public:
|
|||
}
|
||||
|
||||
bool operator==(AffineExpr other) const { return expr == other.expr; }
|
||||
bool operator!=(AffineExpr other) const { return !(*this == other); }
|
||||
explicit operator bool() const { return expr; }
|
||||
|
||||
bool operator!() const { return expr == nullptr; }
|
||||
|
|
|
@ -132,9 +132,9 @@ public:
|
|||
AffineMap getShiftedAffineMap(AffineMap map, int64_t shift);
|
||||
|
||||
// Integer set.
|
||||
IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> constraints,
|
||||
ArrayRef<bool> isEq);
|
||||
IntegerSet getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> constraints,
|
||||
ArrayRef<bool> isEq);
|
||||
// TODO: Helpers for affine map/exprs, etc.
|
||||
protected:
|
||||
MLIRContext *context;
|
||||
|
@ -402,7 +402,7 @@ public:
|
|||
|
||||
/// Creates if statement.
|
||||
IfStmt *createIf(Location *location, ArrayRef<MLValue *> operands,
|
||||
IntegerSet *set);
|
||||
IntegerSet set);
|
||||
|
||||
private:
|
||||
StmtBlock *block = nullptr;
|
||||
|
|
|
@ -38,55 +38,81 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
namespace detail {
|
||||
struct IntegerSetStorage;
|
||||
}
|
||||
|
||||
class MLIRContext;
|
||||
|
||||
/// An integer set representing a conjunction of affine equalities and
|
||||
/// inequalities. An integer set in the IR is immutable like the affine map, but
|
||||
/// integer sets are not unique'd. The affine expressions that make up the
|
||||
/// equalities and inequalities of an integer set are themselves unique.
|
||||
/// equalities and inequalities of an integer set are themselves unique and live
|
||||
/// in the bump allocator.
|
||||
class IntegerSet {
|
||||
public:
|
||||
static IntegerSet *get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> constraints,
|
||||
ArrayRef<bool> eqFlags, MLIRContext *context);
|
||||
typedef detail::IntegerSetStorage ImplType;
|
||||
|
||||
unsigned getNumDims() { return dimCount; }
|
||||
unsigned getNumSymbols() { return symbolCount; }
|
||||
unsigned getNumOperands() { return dimCount + symbolCount; }
|
||||
unsigned getNumConstraints() { return numConstraints; }
|
||||
explicit IntegerSet(ImplType *set = nullptr) : set(set) {}
|
||||
|
||||
ArrayRef<AffineExpr> getConstraints() { return constraints; }
|
||||
static IntegerSet get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> constraints,
|
||||
ArrayRef<bool> eqFlags, MLIRContext *context);
|
||||
|
||||
AffineExpr getConstraint(unsigned idx) { return getConstraints()[idx]; }
|
||||
explicit operator bool() { return set; }
|
||||
bool operator==(IntegerSet other) const { return set == other.set; }
|
||||
|
||||
unsigned getNumDims() const;
|
||||
unsigned getNumSymbols() const;
|
||||
unsigned getNumOperands() const;
|
||||
unsigned getNumConstraints() const;
|
||||
|
||||
ArrayRef<AffineExpr> getConstraints() const;
|
||||
|
||||
AffineExpr getConstraint(unsigned idx) const;
|
||||
|
||||
/// Returns the equality bits, which specify whether each of the constraints
|
||||
/// is an equality or inequality.
|
||||
ArrayRef<bool> getEqFlags() { return eqFlags; }
|
||||
ArrayRef<bool> getEqFlags() const;
|
||||
|
||||
/// Returns true if the idx^th constraint is an equality, false if it is an
|
||||
/// inequality.
|
||||
bool isEq(unsigned idx) { return getEqFlags()[idx]; }
|
||||
bool isEq(unsigned idx) const;
|
||||
|
||||
void print(raw_ostream &os);
|
||||
void dump();
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
friend ::llvm::hash_code hash_value(IntegerSet arg);
|
||||
|
||||
private:
|
||||
IntegerSet(unsigned dimCount, unsigned symbolCount, unsigned numConstraints,
|
||||
ArrayRef<AffineExpr> constraints, ArrayRef<bool> eqFlags);
|
||||
|
||||
~IntegerSet() = delete;
|
||||
|
||||
unsigned dimCount;
|
||||
unsigned symbolCount;
|
||||
unsigned numConstraints;
|
||||
|
||||
/// Array of affine constraints: a constaint is either an equality
|
||||
/// (affine_expr == 0) or an inequality (affine_expr >= 0).
|
||||
ArrayRef<AffineExpr> constraints;
|
||||
|
||||
// Bits to check whether a constraint is an equality or an inequality.
|
||||
ArrayRef<bool> eqFlags;
|
||||
ImplType *set;
|
||||
};
|
||||
|
||||
// Make AffineExpr hashable.
|
||||
inline ::llvm::hash_code hash_value(IntegerSet arg) {
|
||||
return ::llvm::hash_value(arg.set);
|
||||
}
|
||||
|
||||
} // end namespace mlir
|
||||
namespace llvm {
|
||||
|
||||
// IntegerSet hash just like pointers
|
||||
template <> struct DenseMapInfo<mlir::IntegerSet> {
|
||||
static mlir::IntegerSet getEmptyKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::IntegerSet(static_cast<mlir::IntegerSet::ImplType *>(pointer));
|
||||
}
|
||||
static mlir::IntegerSet getTombstoneKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::IntegerSet(static_cast<mlir::IntegerSet::ImplType *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::IntegerSet val) {
|
||||
return mlir::hash_value(val);
|
||||
}
|
||||
static bool isEqual(mlir::IntegerSet LHS, mlir::IntegerSet RHS) {
|
||||
return LHS == RHS;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
#endif // MLIR_IR_INTEGER_SET_H
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#define MLIR_IR_STATEMENTS_H
|
||||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/MLValue.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StmtBlock.h"
|
||||
|
@ -447,7 +448,7 @@ private:
|
|||
class IfStmt : public Statement {
|
||||
public:
|
||||
static IfStmt *create(Location *location, ArrayRef<MLValue *> operands,
|
||||
IntegerSet *set);
|
||||
IntegerSet set);
|
||||
~IfStmt();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -467,7 +468,7 @@ public:
|
|||
|
||||
const AffineCondition getCondition() const;
|
||||
|
||||
IntegerSet *getIntegerSet() const { return set; }
|
||||
IntegerSet getIntegerSet() const { return set; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operands
|
||||
|
@ -528,12 +529,12 @@ private:
|
|||
IfClause *elseClause;
|
||||
|
||||
// The integer set capturing the conditional guard.
|
||||
IntegerSet *set;
|
||||
IntegerSet set;
|
||||
|
||||
// Condition operands.
|
||||
std::vector<StmtOperand> operands;
|
||||
|
||||
explicit IfStmt(Location *location, unsigned numOperands, IntegerSet *set);
|
||||
explicit IfStmt(Location *location, unsigned numOperands, IntegerSet set);
|
||||
};
|
||||
|
||||
/// AffineCondition represents a condition of the 'if' statement.
|
||||
|
@ -546,16 +547,15 @@ private:
|
|||
class AffineCondition {
|
||||
public:
|
||||
const IfStmt *getIfStmt() const { return &stmt; }
|
||||
IntegerSet *getSet() const { return set; }
|
||||
IntegerSet getSet() const { return set; }
|
||||
|
||||
private:
|
||||
// 'if' statement that contains this affine condition.
|
||||
const IfStmt &stmt;
|
||||
// Integer set for this affine condition.
|
||||
IntegerSet *set;
|
||||
IntegerSet set;
|
||||
|
||||
AffineCondition(const IfStmt &stmt, const IntegerSet *set)
|
||||
: stmt(stmt), set(const_cast<IntegerSet *>(set)) {}
|
||||
AffineCondition(const IfStmt &stmt, IntegerSet set) : stmt(stmt), set(set) {}
|
||||
|
||||
friend class IfStmt;
|
||||
};
|
||||
|
|
|
@ -199,8 +199,8 @@ AffineMap MutableAffineMap::getAffineMap() {
|
|||
return AffineMap::get(numDims, numSymbols, results, rangeSizes);
|
||||
}
|
||||
|
||||
MutableIntegerSet::MutableIntegerSet(IntegerSet *set, MLIRContext *context)
|
||||
: numDims(set->getNumDims()), numSymbols(set->getNumSymbols()),
|
||||
MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context)
|
||||
: numDims(set.getNumDims()), numSymbols(set.getNumSymbols()),
|
||||
context(context) {
|
||||
// TODO(bondhugula)
|
||||
}
|
||||
|
|
|
@ -109,7 +109,7 @@ HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
|
|||
ArrayRef<ArrayRef<AffineExpr>> lbs,
|
||||
ArrayRef<ArrayRef<AffineExpr>> ubs,
|
||||
MLIRContext *context,
|
||||
IntegerSet *symbolContext)
|
||||
IntegerSet symbolContext)
|
||||
: context(symbolContext ? MutableIntegerSet(symbolContext, context)
|
||||
: MutableIntegerSet(numDims, numSymbols, context)) {
|
||||
unsigned d = 0;
|
||||
|
|
|
@ -78,7 +78,7 @@ public:
|
|||
|
||||
ArrayRef<AffineMap> getAffineMapIds() const { return affineMapsById; }
|
||||
|
||||
int getIntegerSetId(IntegerSet *integerSet) const {
|
||||
int getIntegerSetId(IntegerSet integerSet) const {
|
||||
auto it = integerSetIds.find(integerSet);
|
||||
if (it == integerSetIds.end()) {
|
||||
return -1;
|
||||
|
@ -86,7 +86,7 @@ public:
|
|||
return it->second;
|
||||
}
|
||||
|
||||
ArrayRef<IntegerSet *> getIntegerSetIds() const { return integerSetsById; }
|
||||
ArrayRef<IntegerSet> getIntegerSetIds() const { return integerSetsById; }
|
||||
|
||||
private:
|
||||
void recordAffineMapReference(AffineMap affineMap) {
|
||||
|
@ -96,7 +96,7 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
void recordIntegerSetReference(IntegerSet *integerSet) {
|
||||
void recordIntegerSetReference(IntegerSet integerSet) {
|
||||
if (integerSetIds.count(integerSet) == 0) {
|
||||
integerSetIds[integerSet] = integerSetsById.size();
|
||||
integerSetsById.push_back(integerSet);
|
||||
|
@ -131,8 +131,8 @@ private:
|
|||
DenseMap<AffineMap, int> affineMapIds;
|
||||
std::vector<AffineMap> affineMapsById;
|
||||
|
||||
DenseMap<IntegerSet *, int> integerSetIds;
|
||||
std::vector<IntegerSet *> integerSetsById;
|
||||
DenseMap<IntegerSet, int> integerSetIds;
|
||||
std::vector<IntegerSet> integerSetsById;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -280,7 +280,7 @@ public:
|
|||
void printAffineMap(AffineMap map);
|
||||
void printAffineExpr(AffineExpr expr);
|
||||
void printAffineConstraint(AffineExpr expr, bool isEq);
|
||||
void printIntegerSet(IntegerSet *set);
|
||||
void printIntegerSet(IntegerSet set);
|
||||
|
||||
protected:
|
||||
raw_ostream &os;
|
||||
|
@ -294,7 +294,7 @@ protected:
|
|||
void printAffineMapId(int affineMapId) const;
|
||||
void printAffineMapReference(AffineMap affineMap);
|
||||
void printIntegerSetId(int integerSetId) const;
|
||||
void printIntegerSetReference(IntegerSet *integerSet);
|
||||
void printIntegerSetReference(IntegerSet integerSet);
|
||||
|
||||
/// This enum is used to represent the binding stength of the enclosing
|
||||
/// context that an AffineExprStorage is being printed in, so we can
|
||||
|
@ -341,14 +341,14 @@ void ModulePrinter::printIntegerSetId(int integerSetId) const {
|
|||
os << "@@set" << integerSetId;
|
||||
}
|
||||
|
||||
void ModulePrinter::printIntegerSetReference(IntegerSet *integerSet) {
|
||||
void ModulePrinter::printIntegerSetReference(IntegerSet integerSet) {
|
||||
int setId;
|
||||
if ((setId = state.getIntegerSetId(integerSet)) >= 0) {
|
||||
// The set will be printed at top of module; so print reference to its id.
|
||||
printIntegerSetId(setId);
|
||||
} else {
|
||||
// Set not in module state so print inline.
|
||||
integerSet->print(os);
|
||||
integerSet.print(os);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -362,7 +362,7 @@ void ModulePrinter::print(const Module *module) {
|
|||
for (const auto &set : state.getIntegerSetIds()) {
|
||||
printIntegerSetId(state.getIntegerSetId(set));
|
||||
os << " = ";
|
||||
set->print(os);
|
||||
set.print(os);
|
||||
os << '\n';
|
||||
}
|
||||
for (auto const &fn : *module)
|
||||
|
@ -729,35 +729,35 @@ void ModulePrinter::printAffineMap(AffineMap map) {
|
|||
os << ')';
|
||||
}
|
||||
|
||||
void ModulePrinter::printIntegerSet(IntegerSet *set) {
|
||||
void ModulePrinter::printIntegerSet(IntegerSet set) {
|
||||
// Dimension identifiers.
|
||||
os << '(';
|
||||
for (unsigned i = 1; i < set->getNumDims(); ++i)
|
||||
for (unsigned i = 1; i < set.getNumDims(); ++i)
|
||||
os << 'd' << i - 1 << ", ";
|
||||
if (set->getNumDims() >= 1)
|
||||
os << 'd' << set->getNumDims() - 1;
|
||||
if (set.getNumDims() >= 1)
|
||||
os << 'd' << set.getNumDims() - 1;
|
||||
os << ')';
|
||||
|
||||
// Symbolic identifiers.
|
||||
if (set->getNumSymbols() != 0) {
|
||||
if (set.getNumSymbols() != 0) {
|
||||
os << '[';
|
||||
for (unsigned i = 0; i < set->getNumSymbols() - 1; ++i)
|
||||
for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
|
||||
os << 's' << i << ", ";
|
||||
if (set->getNumSymbols() >= 1)
|
||||
os << 's' << set->getNumSymbols() - 1;
|
||||
if (set.getNumSymbols() >= 1)
|
||||
os << 's' << set.getNumSymbols() - 1;
|
||||
os << ']';
|
||||
}
|
||||
|
||||
// Print constraints.
|
||||
os << " : (";
|
||||
auto numConstraints = set->getNumConstraints();
|
||||
auto numConstraints = set.getNumConstraints();
|
||||
for (int i = 1; i < numConstraints; ++i) {
|
||||
printAffineConstraint(set->getConstraint(i - 1), set->isEq(i - 1));
|
||||
printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
|
||||
os << ", ";
|
||||
}
|
||||
if (numConstraints >= 1)
|
||||
printAffineConstraint(set->getConstraint(numConstraints - 1),
|
||||
set->isEq(numConstraints - 1));
|
||||
printAffineConstraint(set.getConstraint(numConstraints - 1),
|
||||
set.isEq(numConstraints - 1));
|
||||
os << ')';
|
||||
}
|
||||
|
||||
|
@ -867,7 +867,7 @@ public:
|
|||
void printAffineMap(AffineMap map) {
|
||||
return ModulePrinter::printAffineMapReference(map);
|
||||
}
|
||||
void printIntegerSet(IntegerSet *set) {
|
||||
void printIntegerSet(IntegerSet set) {
|
||||
return ModulePrinter::printIntegerSetReference(set);
|
||||
}
|
||||
void printAffineExpr(AffineExpr expr) {
|
||||
|
@ -1474,9 +1474,9 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
|
|||
|
||||
void MLFunctionPrinter::print(const IfStmt *stmt) {
|
||||
os.indent(numSpaces) << "if ";
|
||||
IntegerSet *set = stmt->getIntegerSet();
|
||||
IntegerSet set = stmt->getIntegerSet();
|
||||
printIntegerSetReference(set);
|
||||
printDimAndSymbolList(stmt->getStmtOperands(), set->getNumDims());
|
||||
printDimAndSymbolList(stmt->getStmtOperands(), set.getNumDims());
|
||||
os << " {\n";
|
||||
print(stmt->getThen());
|
||||
os.indent(numSpaces) << "}";
|
||||
|
@ -1514,7 +1514,7 @@ void AffineMap::dump() const {
|
|||
llvm::errs() << "\n";
|
||||
}
|
||||
|
||||
void IntegerSet::dump() {
|
||||
void IntegerSet::dump() const {
|
||||
print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
}
|
||||
|
@ -1534,9 +1534,9 @@ void AffineMap::print(raw_ostream &os) const {
|
|||
ModulePrinter(os, state).printAffineMap(*this);
|
||||
}
|
||||
|
||||
void IntegerSet::print(raw_ostream &os) {
|
||||
void IntegerSet::print(raw_ostream &os) const {
|
||||
ModuleState state(/*no context is known*/ nullptr);
|
||||
ModulePrinter(os, state).printIntegerSet(this);
|
||||
ModulePrinter(os, state).printIntegerSet(*this);
|
||||
}
|
||||
|
||||
void SSAValue::print(raw_ostream &os) const {
|
||||
|
|
|
@ -171,9 +171,9 @@ AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
|
|||
return mlir::getAffineConstantExpr(constant, context);
|
||||
}
|
||||
|
||||
IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> constraints,
|
||||
ArrayRef<bool> isEq) {
|
||||
IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> constraints,
|
||||
ArrayRef<bool> isEq) {
|
||||
return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context);
|
||||
}
|
||||
|
||||
|
@ -299,7 +299,7 @@ ForStmt *MLFuncBuilder::createFor(Location *location, int64_t lb, int64_t ub,
|
|||
}
|
||||
|
||||
IfStmt *MLFuncBuilder::createIf(Location *location,
|
||||
ArrayRef<MLValue *> operands, IntegerSet *set) {
|
||||
ArrayRef<MLValue *> operands, IntegerSet set) {
|
||||
auto *stmt = IfStmt::create(location, operands, set);
|
||||
block->getStatements().insert(insertPoint, stmt);
|
||||
return stmt;
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
//===- IntegerSet.cpp - MLIR Integer Set class ----------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
|
@ -17,13 +16,31 @@
|
|||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "IntegerSetDetail.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
IntegerSet::IntegerSet(unsigned dimCount, unsigned symbolCount,
|
||||
unsigned numConstraints,
|
||||
ArrayRef<AffineExpr> constraints, ArrayRef<bool> eqFlags)
|
||||
: dimCount(dimCount), symbolCount(symbolCount),
|
||||
numConstraints(numConstraints), constraints(constraints),
|
||||
eqFlags(eqFlags) {}
|
||||
unsigned IntegerSet::getNumDims() const { return set->dimCount; }
|
||||
unsigned IntegerSet::getNumSymbols() const { return set->symbolCount; }
|
||||
unsigned IntegerSet::getNumOperands() const {
|
||||
return set->dimCount + set->symbolCount;
|
||||
}
|
||||
unsigned IntegerSet::getNumConstraints() const { return set->numConstraints; }
|
||||
|
||||
ArrayRef<AffineExpr> IntegerSet::getConstraints() const {
|
||||
return set->constraints;
|
||||
}
|
||||
|
||||
AffineExpr IntegerSet::getConstraint(unsigned idx) const {
|
||||
return getConstraints()[idx];
|
||||
}
|
||||
|
||||
/// Returns the equality bits, which specify whether each of the constraints
|
||||
/// is an equality or inequality.
|
||||
ArrayRef<bool> IntegerSet::getEqFlags() const { return set->eqFlags; }
|
||||
|
||||
/// Returns true if the idx^th constraint is an equality, false if it is an
|
||||
/// inequality.
|
||||
bool IntegerSet::isEq(unsigned idx) const { return getEqFlags()[idx]; }
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
//===- IntegerSetDetail.h - MLIR IntegerSet storage details -----*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This holds implementation details of IntegerSet.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef INTEGERSETDETAIL_H_
|
||||
#define INTEGERSETDETAIL_H_
|
||||
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
|
||||
struct IntegerSetStorage {
|
||||
unsigned dimCount;
|
||||
unsigned symbolCount;
|
||||
unsigned numConstraints;
|
||||
|
||||
/// Array of affine constraints: a constraint is either an equality
|
||||
/// (affine_expr == 0) or an inequality (affine_expr >= 0).
|
||||
ArrayRef<AffineExpr> constraints;
|
||||
|
||||
// Bits to check whether a constraint is an equality or an inequality.
|
||||
ArrayRef<bool> eqFlags;
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
} // end namespace mlir
|
||||
#endif // INTEGERSETDETAIL_H_
|
|
@ -19,6 +19,7 @@
|
|||
#include "AffineExprDetail.h"
|
||||
#include "AffineMapDetail.h"
|
||||
#include "AttributeListStorage.h"
|
||||
#include "IntegerSetDetail.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
@ -1126,21 +1127,24 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
|
|||
// But they aren't uniqued like AffineMap's; there isn't an advantage to.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
IntegerSet *IntegerSet::get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> constraints,
|
||||
ArrayRef<bool> eqFlags, MLIRContext *context) {
|
||||
IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
|
||||
ArrayRef<AffineExpr> constraints,
|
||||
ArrayRef<bool> eqFlags, MLIRContext *context) {
|
||||
assert(eqFlags.size() == constraints.size());
|
||||
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Allocate them into the bump pointer.
|
||||
auto *res = impl.allocator.Allocate<IntegerSet>();
|
||||
auto *res = impl.allocator.Allocate<IntegerSetStorage>();
|
||||
|
||||
// Copy the equalities and inequalities into the bump pointer.
|
||||
constraints = impl.copyInto(ArrayRef<AffineExpr>(constraints));
|
||||
eqFlags = impl.copyInto(ArrayRef<bool>(eqFlags));
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
return new (res) IntegerSet(dimCount, symbolCount, constraints.size(),
|
||||
constraints, eqFlags);
|
||||
res = new (res) IntegerSetStorage{dimCount, symbolCount,
|
||||
static_cast<unsigned>(constraints.size()),
|
||||
constraints, eqFlags};
|
||||
|
||||
return IntegerSet(res);
|
||||
}
|
||||
|
|
|
@ -439,7 +439,7 @@ bool ForStmt::constantFoldBound(bool lower) {
|
|||
// IfStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
IfStmt::IfStmt(Location *location, unsigned numOperands, IntegerSet *set)
|
||||
IfStmt::IfStmt(Location *location, unsigned numOperands, IntegerSet set)
|
||||
: Statement(Kind::If, location), thenClause(this), elseClause(nullptr),
|
||||
set(set) {
|
||||
operands.reserve(numOperands);
|
||||
|
@ -454,9 +454,9 @@ IfStmt::~IfStmt() {
|
|||
}
|
||||
|
||||
IfStmt *IfStmt::create(Location *location, ArrayRef<MLValue *> operands,
|
||||
IntegerSet *set) {
|
||||
IntegerSet set) {
|
||||
unsigned numOperands = operands.size();
|
||||
assert(numOperands == set->getNumOperands() &&
|
||||
assert(numOperands == set.getNumOperands() &&
|
||||
"operand cound does not match the integer set operand count");
|
||||
|
||||
IfStmt *stmt = new IfStmt(location, numOperands, set);
|
||||
|
|
|
@ -72,7 +72,7 @@ public:
|
|||
llvm::StringMap<AffineMap> affineMapDefinitions;
|
||||
|
||||
// A map from integer set identifier to IntegerSet.
|
||||
llvm::StringMap<IntegerSet *> integerSetDefinitions;
|
||||
llvm::StringMap<IntegerSet> integerSetDefinitions;
|
||||
|
||||
// This keeps track of all forward references to functions along with the
|
||||
// temporary function used to represent them.
|
||||
|
@ -202,8 +202,8 @@ public:
|
|||
// Polyhedral structures.
|
||||
AffineMap parseAffineMapInline();
|
||||
AffineMap parseAffineMapReference();
|
||||
IntegerSet *parseIntegerSetInline();
|
||||
IntegerSet *parseIntegerSetReference();
|
||||
IntegerSet parseIntegerSetInline();
|
||||
IntegerSet parseIntegerSetReference();
|
||||
|
||||
private:
|
||||
// The Parser is subclassed and reinstantiated. Do not add additional
|
||||
|
@ -866,7 +866,7 @@ public:
|
|||
explicit AffineParser(ParserState &state) : Parser(state) {}
|
||||
|
||||
AffineMap parseAffineMapInline();
|
||||
IntegerSet *parseIntegerSetInline();
|
||||
IntegerSet parseIntegerSetInline();
|
||||
|
||||
private:
|
||||
// Binary affine op parsing.
|
||||
|
@ -2522,23 +2522,23 @@ AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
|
|||
/// affine-constraint-conjunction ::= /*empty*/
|
||||
/// | affine-constraint (`,` affine-constraint)*
|
||||
///
|
||||
IntegerSet *AffineParser::parseIntegerSetInline() {
|
||||
IntegerSet AffineParser::parseIntegerSetInline() {
|
||||
unsigned numDims = 0, numSymbols = 0;
|
||||
|
||||
// List of dimensional identifiers.
|
||||
if (parseDimIdList(numDims))
|
||||
return nullptr;
|
||||
return IntegerSet();
|
||||
|
||||
// Symbols are optional.
|
||||
if (getToken().is(Token::l_square)) {
|
||||
if (parseSymbolIdList(numSymbols))
|
||||
return nullptr;
|
||||
return IntegerSet();
|
||||
}
|
||||
|
||||
if (parseToken(Token::colon, "expected ':' or '['") ||
|
||||
parseToken(Token::l_paren,
|
||||
"expected '(' at start of integer set constraint list"))
|
||||
return nullptr;
|
||||
return IntegerSet();
|
||||
|
||||
SmallVector<AffineExpr, 4> constraints;
|
||||
SmallVector<bool, 4> isEqs;
|
||||
|
@ -2557,13 +2557,13 @@ IntegerSet *AffineParser::parseIntegerSetInline() {
|
|||
// Grammar: affine-constraint-conjunct ::= `(` affine-constraint (`,`
|
||||
// affine-constraint)* `)
|
||||
if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
|
||||
return nullptr;
|
||||
return IntegerSet();
|
||||
|
||||
// Parsed a valid integer set.
|
||||
return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
|
||||
}
|
||||
|
||||
IntegerSet *Parser::parseIntegerSetInline() {
|
||||
IntegerSet Parser::parseIntegerSetInline() {
|
||||
return AffineParser(state).parseIntegerSetInline();
|
||||
}
|
||||
|
||||
|
@ -2571,14 +2571,14 @@ IntegerSet *Parser::parseIntegerSetInline() {
|
|||
/// integer-set ::= integer-set-id | integer-set-inline
|
||||
/// integer-set-id ::= `@@` suffix-id
|
||||
///
|
||||
IntegerSet *Parser::parseIntegerSetReference() {
|
||||
IntegerSet Parser::parseIntegerSetReference() {
|
||||
// TODO: change '@@' integer set prefix to '#'.
|
||||
if (getToken().is(Token::double_at_identifier)) {
|
||||
// Parse integer set identifier and verify that it exists.
|
||||
StringRef integerSetId = getTokenSpelling().drop_front(2);
|
||||
if (getState().integerSetDefinitions.count(integerSetId) == 0)
|
||||
return (emitError("undefined integer set id '" + integerSetId + "'"),
|
||||
nullptr);
|
||||
IntegerSet());
|
||||
consumeToken(Token::double_at_identifier);
|
||||
return getState().integerSetDefinitions[integerSetId];
|
||||
}
|
||||
|
@ -2597,12 +2597,12 @@ ParseResult MLFunctionParser::parseIfStmt() {
|
|||
auto loc = getToken().getLoc();
|
||||
consumeToken(Token::kw_if);
|
||||
|
||||
IntegerSet *set = parseIntegerSetReference();
|
||||
IntegerSet set = parseIntegerSetReference();
|
||||
if (!set)
|
||||
return ParseFailure;
|
||||
|
||||
SmallVector<MLValue *, 4> operands;
|
||||
if (parseDimAndSymbolList(operands, set->getNumDims(), set->getNumOperands(),
|
||||
if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(),
|
||||
"integer set"))
|
||||
return ParseFailure;
|
||||
|
||||
|
@ -2757,7 +2757,7 @@ ParseResult ModuleParser::parseIntegerSetDef() {
|
|||
StringRef integerSetId = getTokenSpelling().drop_front(2);
|
||||
|
||||
// Check for redefinitions (a default entry is created if one doesn't exist)
|
||||
auto *&entry = getState().integerSetDefinitions[integerSetId];
|
||||
auto &entry = getState().integerSetDefinitions[integerSetId];
|
||||
if (entry)
|
||||
return emitError("redefinition of integer set id '" + integerSetId + "'");
|
||||
|
||||
|
|
Loading…
Reference in New Issue