From b04f881dcb8ec081031c3a64c937c46b05776a96 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 10 Oct 2018 09:45:59 -0700 Subject: [PATCH] [MLIR] IntegerSet value type This CL applies the same pattern as AffineMap to IntegerSet: a simple struct that acts as the storage is allocated in the bump pointer. The IntegerSet is immutable and accessed everywhere by value. Note that unlike AffineMap, it is not possible to remove the MLIRContext parameter when constructing an IntegerSet for now. One possible way to achieve this would be to add an enum to distinguish between the mathematically empty set, the universe set and other sets. This is left for future discussion. PiperOrigin-RevId: 216545361 --- mlir/include/mlir/Analysis/AffineStructures.h | 2 +- .../mlir/Analysis/HyperRectangularSet.h | 3 +- mlir/include/mlir/IR/AffineExpr.h | 1 + mlir/include/mlir/IR/Builders.h | 8 +- mlir/include/mlir/IR/IntegerSet.h | 84 ++++++++++++------- mlir/include/mlir/IR/Statements.h | 16 ++-- mlir/lib/Analysis/AffineStructures.cpp | 4 +- mlir/lib/Analysis/HyperRectangularSet.cpp | 2 +- mlir/lib/IR/AsmPrinter.cpp | 56 ++++++------- mlir/lib/IR/Builders.cpp | 8 +- mlir/lib/IR/IntegerSet.cpp | 31 +++++-- mlir/lib/IR/IntegerSetDetail.h | 46 ++++++++++ mlir/lib/IR/MLIRContext.cpp | 16 ++-- mlir/lib/IR/Statement.cpp | 6 +- mlir/lib/Parser/Parser.cpp | 30 +++---- 15 files changed, 204 insertions(+), 109 deletions(-) create mode 100644 mlir/lib/IR/IntegerSetDetail.h diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 6533bc0865a7..101a00004af2 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -78,7 +78,7 @@ private: /// A mutable integer set. Its affine expressions are however unique. struct MutableIntegerSet { public: - MutableIntegerSet(IntegerSet *set, MLIRContext *context); + MutableIntegerSet(IntegerSet set, MLIRContext *context); /// Create a universal set (no constraints). MutableIntegerSet(unsigned numDims, unsigned numSymbols, diff --git a/mlir/include/mlir/Analysis/HyperRectangularSet.h b/mlir/include/mlir/Analysis/HyperRectangularSet.h index ad2b2560dc4f..dbccf1c36b13 100644 --- a/mlir/include/mlir/Analysis/HyperRectangularSet.h +++ b/mlir/include/mlir/Analysis/HyperRectangularSet.h @@ -26,6 +26,7 @@ #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/IntegerSet.h" #include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist_node.h" @@ -95,7 +96,7 @@ public: HyperRectangularSet(unsigned numDims, unsigned numSymbols, ArrayRef> lbs, ArrayRef> ubs, MLIRContext *context, - IntegerSet *symbolContext = nullptr); + IntegerSet symbolContext = IntegerSet()); unsigned getNumDims() const { return numDims; } unsigned getNumSymbols() const { return numSymbols; } diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index 301ca7372dfa..d60f4e165934 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -84,6 +84,7 @@ public: } bool operator==(AffineExpr other) const { return expr == other.expr; } + bool operator!=(AffineExpr other) const { return !(*this == other); } explicit operator bool() const { return expr; } bool operator!() const { return expr == nullptr; } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index c75ceb4eb31c..5eb689dc8e10 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -132,9 +132,9 @@ public: AffineMap getShiftedAffineMap(AffineMap map, int64_t shift); // Integer set. - IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, - ArrayRef isEq); + IntegerSet getIntegerSet(unsigned dimCount, unsigned symbolCount, + ArrayRef constraints, + ArrayRef isEq); // TODO: Helpers for affine map/exprs, etc. protected: MLIRContext *context; @@ -402,7 +402,7 @@ public: /// Creates if statement. IfStmt *createIf(Location *location, ArrayRef operands, - IntegerSet *set); + IntegerSet set); private: StmtBlock *block = nullptr; diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h index cb7eec841ced..e79cb47ebce3 100644 --- a/mlir/include/mlir/IR/IntegerSet.h +++ b/mlir/include/mlir/IR/IntegerSet.h @@ -38,55 +38,81 @@ namespace mlir { +namespace detail { +struct IntegerSetStorage; +} + class MLIRContext; /// An integer set representing a conjunction of affine equalities and /// inequalities. An integer set in the IR is immutable like the affine map, but /// integer sets are not unique'd. The affine expressions that make up the -/// equalities and inequalities of an integer set are themselves unique. +/// equalities and inequalities of an integer set are themselves unique and live +/// in the bump allocator. class IntegerSet { public: - static IntegerSet *get(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, - ArrayRef eqFlags, MLIRContext *context); + typedef detail::IntegerSetStorage ImplType; - unsigned getNumDims() { return dimCount; } - unsigned getNumSymbols() { return symbolCount; } - unsigned getNumOperands() { return dimCount + symbolCount; } - unsigned getNumConstraints() { return numConstraints; } + explicit IntegerSet(ImplType *set = nullptr) : set(set) {} - ArrayRef getConstraints() { return constraints; } + static IntegerSet get(unsigned dimCount, unsigned symbolCount, + ArrayRef constraints, + ArrayRef eqFlags, MLIRContext *context); - AffineExpr getConstraint(unsigned idx) { return getConstraints()[idx]; } + explicit operator bool() { return set; } + bool operator==(IntegerSet other) const { return set == other.set; } + + unsigned getNumDims() const; + unsigned getNumSymbols() const; + unsigned getNumOperands() const; + unsigned getNumConstraints() const; + + ArrayRef getConstraints() const; + + AffineExpr getConstraint(unsigned idx) const; /// Returns the equality bits, which specify whether each of the constraints /// is an equality or inequality. - ArrayRef getEqFlags() { return eqFlags; } + ArrayRef getEqFlags() const; /// Returns true if the idx^th constraint is an equality, false if it is an /// inequality. - bool isEq(unsigned idx) { return getEqFlags()[idx]; } + bool isEq(unsigned idx) const; - void print(raw_ostream &os); - void dump(); + void print(raw_ostream &os) const; + void dump() const; + + friend ::llvm::hash_code hash_value(IntegerSet arg); private: - IntegerSet(unsigned dimCount, unsigned symbolCount, unsigned numConstraints, - ArrayRef constraints, ArrayRef eqFlags); - - ~IntegerSet() = delete; - - unsigned dimCount; - unsigned symbolCount; - unsigned numConstraints; - - /// Array of affine constraints: a constaint is either an equality - /// (affine_expr == 0) or an inequality (affine_expr >= 0). - ArrayRef constraints; - - // Bits to check whether a constraint is an equality or an inequality. - ArrayRef eqFlags; + ImplType *set; }; +// Make AffineExpr hashable. +inline ::llvm::hash_code hash_value(IntegerSet arg) { + return ::llvm::hash_value(arg.set); +} + } // end namespace mlir +namespace llvm { + +// IntegerSet hash just like pointers +template <> struct DenseMapInfo { + static mlir::IntegerSet getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::IntegerSet(static_cast(pointer)); + } + static mlir::IntegerSet getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::IntegerSet(static_cast(pointer)); + } + static unsigned getHashValue(mlir::IntegerSet val) { + return mlir::hash_value(val); + } + static bool isEqual(mlir::IntegerSet LHS, mlir::IntegerSet RHS) { + return LHS == RHS; + } +}; + +} // namespace llvm #endif // MLIR_IR_INTEGER_SET_H diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index 0087390481d5..b1679f2ea0b2 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -23,6 +23,7 @@ #define MLIR_IR_STATEMENTS_H #include "mlir/IR/AffineMap.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLValue.h" #include "mlir/IR/Operation.h" #include "mlir/IR/StmtBlock.h" @@ -447,7 +448,7 @@ private: class IfStmt : public Statement { public: static IfStmt *create(Location *location, ArrayRef operands, - IntegerSet *set); + IntegerSet set); ~IfStmt(); //===--------------------------------------------------------------------===// @@ -467,7 +468,7 @@ public: const AffineCondition getCondition() const; - IntegerSet *getIntegerSet() const { return set; } + IntegerSet getIntegerSet() const { return set; } //===--------------------------------------------------------------------===// // Operands @@ -528,12 +529,12 @@ private: IfClause *elseClause; // The integer set capturing the conditional guard. - IntegerSet *set; + IntegerSet set; // Condition operands. std::vector operands; - explicit IfStmt(Location *location, unsigned numOperands, IntegerSet *set); + explicit IfStmt(Location *location, unsigned numOperands, IntegerSet set); }; /// AffineCondition represents a condition of the 'if' statement. @@ -546,16 +547,15 @@ private: class AffineCondition { public: const IfStmt *getIfStmt() const { return &stmt; } - IntegerSet *getSet() const { return set; } + IntegerSet getSet() const { return set; } private: // 'if' statement that contains this affine condition. const IfStmt &stmt; // Integer set for this affine condition. - IntegerSet *set; + IntegerSet set; - AffineCondition(const IfStmt &stmt, const IntegerSet *set) - : stmt(stmt), set(const_cast(set)) {} + AffineCondition(const IfStmt &stmt, IntegerSet set) : stmt(stmt), set(set) {} friend class IfStmt; }; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index e1d0d357a231..5460678e8c1e 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -199,8 +199,8 @@ AffineMap MutableAffineMap::getAffineMap() { return AffineMap::get(numDims, numSymbols, results, rangeSizes); } -MutableIntegerSet::MutableIntegerSet(IntegerSet *set, MLIRContext *context) - : numDims(set->getNumDims()), numSymbols(set->getNumSymbols()), +MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context) + : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()), context(context) { // TODO(bondhugula) } diff --git a/mlir/lib/Analysis/HyperRectangularSet.cpp b/mlir/lib/Analysis/HyperRectangularSet.cpp index bd1361b1d93a..90272c4f9056 100644 --- a/mlir/lib/Analysis/HyperRectangularSet.cpp +++ b/mlir/lib/Analysis/HyperRectangularSet.cpp @@ -109,7 +109,7 @@ HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols, ArrayRef> lbs, ArrayRef> ubs, MLIRContext *context, - IntegerSet *symbolContext) + IntegerSet symbolContext) : context(symbolContext ? MutableIntegerSet(symbolContext, context) : MutableIntegerSet(numDims, numSymbols, context)) { unsigned d = 0; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index a9d9c22d5afe..5263c80e8e2a 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -78,7 +78,7 @@ public: ArrayRef getAffineMapIds() const { return affineMapsById; } - int getIntegerSetId(IntegerSet *integerSet) const { + int getIntegerSetId(IntegerSet integerSet) const { auto it = integerSetIds.find(integerSet); if (it == integerSetIds.end()) { return -1; @@ -86,7 +86,7 @@ public: return it->second; } - ArrayRef getIntegerSetIds() const { return integerSetsById; } + ArrayRef getIntegerSetIds() const { return integerSetsById; } private: void recordAffineMapReference(AffineMap affineMap) { @@ -96,7 +96,7 @@ private: } } - void recordIntegerSetReference(IntegerSet *integerSet) { + void recordIntegerSetReference(IntegerSet integerSet) { if (integerSetIds.count(integerSet) == 0) { integerSetIds[integerSet] = integerSetsById.size(); integerSetsById.push_back(integerSet); @@ -131,8 +131,8 @@ private: DenseMap affineMapIds; std::vector affineMapsById; - DenseMap integerSetIds; - std::vector integerSetsById; + DenseMap integerSetIds; + std::vector integerSetsById; }; } // end anonymous namespace @@ -280,7 +280,7 @@ public: void printAffineMap(AffineMap map); void printAffineExpr(AffineExpr expr); void printAffineConstraint(AffineExpr expr, bool isEq); - void printIntegerSet(IntegerSet *set); + void printIntegerSet(IntegerSet set); protected: raw_ostream &os; @@ -294,7 +294,7 @@ protected: void printAffineMapId(int affineMapId) const; void printAffineMapReference(AffineMap affineMap); void printIntegerSetId(int integerSetId) const; - void printIntegerSetReference(IntegerSet *integerSet); + void printIntegerSetReference(IntegerSet integerSet); /// This enum is used to represent the binding stength of the enclosing /// context that an AffineExprStorage is being printed in, so we can @@ -341,14 +341,14 @@ void ModulePrinter::printIntegerSetId(int integerSetId) const { os << "@@set" << integerSetId; } -void ModulePrinter::printIntegerSetReference(IntegerSet *integerSet) { +void ModulePrinter::printIntegerSetReference(IntegerSet integerSet) { int setId; if ((setId = state.getIntegerSetId(integerSet)) >= 0) { // The set will be printed at top of module; so print reference to its id. printIntegerSetId(setId); } else { // Set not in module state so print inline. - integerSet->print(os); + integerSet.print(os); } } @@ -362,7 +362,7 @@ void ModulePrinter::print(const Module *module) { for (const auto &set : state.getIntegerSetIds()) { printIntegerSetId(state.getIntegerSetId(set)); os << " = "; - set->print(os); + set.print(os); os << '\n'; } for (auto const &fn : *module) @@ -729,35 +729,35 @@ void ModulePrinter::printAffineMap(AffineMap map) { os << ')'; } -void ModulePrinter::printIntegerSet(IntegerSet *set) { +void ModulePrinter::printIntegerSet(IntegerSet set) { // Dimension identifiers. os << '('; - for (unsigned i = 1; i < set->getNumDims(); ++i) + for (unsigned i = 1; i < set.getNumDims(); ++i) os << 'd' << i - 1 << ", "; - if (set->getNumDims() >= 1) - os << 'd' << set->getNumDims() - 1; + if (set.getNumDims() >= 1) + os << 'd' << set.getNumDims() - 1; os << ')'; // Symbolic identifiers. - if (set->getNumSymbols() != 0) { + if (set.getNumSymbols() != 0) { os << '['; - for (unsigned i = 0; i < set->getNumSymbols() - 1; ++i) + for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) os << 's' << i << ", "; - if (set->getNumSymbols() >= 1) - os << 's' << set->getNumSymbols() - 1; + if (set.getNumSymbols() >= 1) + os << 's' << set.getNumSymbols() - 1; os << ']'; } // Print constraints. os << " : ("; - auto numConstraints = set->getNumConstraints(); + auto numConstraints = set.getNumConstraints(); for (int i = 1; i < numConstraints; ++i) { - printAffineConstraint(set->getConstraint(i - 1), set->isEq(i - 1)); + printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); os << ", "; } if (numConstraints >= 1) - printAffineConstraint(set->getConstraint(numConstraints - 1), - set->isEq(numConstraints - 1)); + printAffineConstraint(set.getConstraint(numConstraints - 1), + set.isEq(numConstraints - 1)); os << ')'; } @@ -867,7 +867,7 @@ public: void printAffineMap(AffineMap map) { return ModulePrinter::printAffineMapReference(map); } - void printIntegerSet(IntegerSet *set) { + void printIntegerSet(IntegerSet set) { return ModulePrinter::printIntegerSetReference(set); } void printAffineExpr(AffineExpr expr) { @@ -1474,9 +1474,9 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) { void MLFunctionPrinter::print(const IfStmt *stmt) { os.indent(numSpaces) << "if "; - IntegerSet *set = stmt->getIntegerSet(); + IntegerSet set = stmt->getIntegerSet(); printIntegerSetReference(set); - printDimAndSymbolList(stmt->getStmtOperands(), set->getNumDims()); + printDimAndSymbolList(stmt->getStmtOperands(), set.getNumDims()); os << " {\n"; print(stmt->getThen()); os.indent(numSpaces) << "}"; @@ -1514,7 +1514,7 @@ void AffineMap::dump() const { llvm::errs() << "\n"; } -void IntegerSet::dump() { +void IntegerSet::dump() const { print(llvm::errs()); llvm::errs() << "\n"; } @@ -1534,9 +1534,9 @@ void AffineMap::print(raw_ostream &os) const { ModulePrinter(os, state).printAffineMap(*this); } -void IntegerSet::print(raw_ostream &os) { +void IntegerSet::print(raw_ostream &os) const { ModuleState state(/*no context is known*/ nullptr); - ModulePrinter(os, state).printIntegerSet(this); + ModulePrinter(os, state).printIntegerSet(*this); } void SSAValue::print(raw_ostream &os) const { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index c24af33d6788..8d1c02ef9ea3 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -171,9 +171,9 @@ AffineExpr Builder::getAffineConstantExpr(int64_t constant) { return mlir::getAffineConstantExpr(constant, context); } -IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, - ArrayRef isEq) { +IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount, + ArrayRef constraints, + ArrayRef isEq) { return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context); } @@ -299,7 +299,7 @@ ForStmt *MLFuncBuilder::createFor(Location *location, int64_t lb, int64_t ub, } IfStmt *MLFuncBuilder::createIf(Location *location, - ArrayRef operands, IntegerSet *set) { + ArrayRef operands, IntegerSet set) { auto *stmt = IfStmt::create(location, operands, set); block->getStatements().insert(insertPoint, stmt); return stmt; diff --git a/mlir/lib/IR/IntegerSet.cpp b/mlir/lib/IR/IntegerSet.cpp index cfd838516dbe..889bdd403af3 100644 --- a/mlir/lib/IR/IntegerSet.cpp +++ b/mlir/lib/IR/IntegerSet.cpp @@ -1,4 +1,3 @@ - //===- IntegerSet.cpp - MLIR Integer Set class ----------------------------===// // // Copyright 2019 The MLIR Authors. @@ -17,13 +16,31 @@ // ============================================================================= #include "mlir/IR/IntegerSet.h" +#include "IntegerSetDetail.h" #include "mlir/IR/AffineExpr.h" using namespace mlir; +using namespace mlir::detail; -IntegerSet::IntegerSet(unsigned dimCount, unsigned symbolCount, - unsigned numConstraints, - ArrayRef constraints, ArrayRef eqFlags) - : dimCount(dimCount), symbolCount(symbolCount), - numConstraints(numConstraints), constraints(constraints), - eqFlags(eqFlags) {} +unsigned IntegerSet::getNumDims() const { return set->dimCount; } +unsigned IntegerSet::getNumSymbols() const { return set->symbolCount; } +unsigned IntegerSet::getNumOperands() const { + return set->dimCount + set->symbolCount; +} +unsigned IntegerSet::getNumConstraints() const { return set->numConstraints; } + +ArrayRef IntegerSet::getConstraints() const { + return set->constraints; +} + +AffineExpr IntegerSet::getConstraint(unsigned idx) const { + return getConstraints()[idx]; +} + +/// Returns the equality bits, which specify whether each of the constraints +/// is an equality or inequality. +ArrayRef IntegerSet::getEqFlags() const { return set->eqFlags; } + +/// Returns true if the idx^th constraint is an equality, false if it is an +/// inequality. +bool IntegerSet::isEq(unsigned idx) const { return getEqFlags()[idx]; } diff --git a/mlir/lib/IR/IntegerSetDetail.h b/mlir/lib/IR/IntegerSetDetail.h new file mode 100644 index 000000000000..59b3f87ec296 --- /dev/null +++ b/mlir/lib/IR/IntegerSetDetail.h @@ -0,0 +1,46 @@ +//===- IntegerSetDetail.h - MLIR IntegerSet storage details -----*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This holds implementation details of IntegerSet. +// +//===----------------------------------------------------------------------===// + +#ifndef INTEGERSETDETAIL_H_ +#define INTEGERSETDETAIL_H_ + +#include "mlir/IR/AffineExpr.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { +namespace detail { + +struct IntegerSetStorage { + unsigned dimCount; + unsigned symbolCount; + unsigned numConstraints; + + /// Array of affine constraints: a constraint is either an equality + /// (affine_expr == 0) or an inequality (affine_expr >= 0). + ArrayRef constraints; + + // Bits to check whether a constraint is an equality or an inequality. + ArrayRef eqFlags; +}; + +} // end namespace detail +} // end namespace mlir +#endif // INTEGERSETDETAIL_H_ diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 220ef709c94a..9394d6ea9cea 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -19,6 +19,7 @@ #include "AffineExprDetail.h" #include "AffineMapDetail.h" #include "AttributeListStorage.h" +#include "IntegerSetDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -1126,21 +1127,24 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { // But they aren't uniqued like AffineMap's; there isn't an advantage to. //===----------------------------------------------------------------------===// -IntegerSet *IntegerSet::get(unsigned dimCount, unsigned symbolCount, - ArrayRef constraints, - ArrayRef eqFlags, MLIRContext *context) { +IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, + ArrayRef constraints, + ArrayRef eqFlags, MLIRContext *context) { assert(eqFlags.size() == constraints.size()); auto &impl = context->getImpl(); // Allocate them into the bump pointer. - auto *res = impl.allocator.Allocate(); + auto *res = impl.allocator.Allocate(); // Copy the equalities and inequalities into the bump pointer. constraints = impl.copyInto(ArrayRef(constraints)); eqFlags = impl.copyInto(ArrayRef(eqFlags)); // Initialize the memory using placement new. - return new (res) IntegerSet(dimCount, symbolCount, constraints.size(), - constraints, eqFlags); + res = new (res) IntegerSetStorage{dimCount, symbolCount, + static_cast(constraints.size()), + constraints, eqFlags}; + + return IntegerSet(res); } diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 263cc595d291..e4cb3e5bd873 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -439,7 +439,7 @@ bool ForStmt::constantFoldBound(bool lower) { // IfStmt //===----------------------------------------------------------------------===// -IfStmt::IfStmt(Location *location, unsigned numOperands, IntegerSet *set) +IfStmt::IfStmt(Location *location, unsigned numOperands, IntegerSet set) : Statement(Kind::If, location), thenClause(this), elseClause(nullptr), set(set) { operands.reserve(numOperands); @@ -454,9 +454,9 @@ IfStmt::~IfStmt() { } IfStmt *IfStmt::create(Location *location, ArrayRef operands, - IntegerSet *set) { + IntegerSet set) { unsigned numOperands = operands.size(); - assert(numOperands == set->getNumOperands() && + assert(numOperands == set.getNumOperands() && "operand cound does not match the integer set operand count"); IfStmt *stmt = new IfStmt(location, numOperands, set); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 51a1f6858612..7315d019075f 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -72,7 +72,7 @@ public: llvm::StringMap affineMapDefinitions; // A map from integer set identifier to IntegerSet. - llvm::StringMap integerSetDefinitions; + llvm::StringMap integerSetDefinitions; // This keeps track of all forward references to functions along with the // temporary function used to represent them. @@ -202,8 +202,8 @@ public: // Polyhedral structures. AffineMap parseAffineMapInline(); AffineMap parseAffineMapReference(); - IntegerSet *parseIntegerSetInline(); - IntegerSet *parseIntegerSetReference(); + IntegerSet parseIntegerSetInline(); + IntegerSet parseIntegerSetReference(); private: // The Parser is subclassed and reinstantiated. Do not add additional @@ -866,7 +866,7 @@ public: explicit AffineParser(ParserState &state) : Parser(state) {} AffineMap parseAffineMapInline(); - IntegerSet *parseIntegerSetInline(); + IntegerSet parseIntegerSetInline(); private: // Binary affine op parsing. @@ -2522,23 +2522,23 @@ AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { /// affine-constraint-conjunction ::= /*empty*/ /// | affine-constraint (`,` affine-constraint)* /// -IntegerSet *AffineParser::parseIntegerSetInline() { +IntegerSet AffineParser::parseIntegerSetInline() { unsigned numDims = 0, numSymbols = 0; // List of dimensional identifiers. if (parseDimIdList(numDims)) - return nullptr; + return IntegerSet(); // Symbols are optional. if (getToken().is(Token::l_square)) { if (parseSymbolIdList(numSymbols)) - return nullptr; + return IntegerSet(); } if (parseToken(Token::colon, "expected ':' or '['") || parseToken(Token::l_paren, "expected '(' at start of integer set constraint list")) - return nullptr; + return IntegerSet(); SmallVector constraints; SmallVector isEqs; @@ -2557,13 +2557,13 @@ IntegerSet *AffineParser::parseIntegerSetInline() { // Grammar: affine-constraint-conjunct ::= `(` affine-constraint (`,` // affine-constraint)* `) if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true)) - return nullptr; + return IntegerSet(); // Parsed a valid integer set. return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); } -IntegerSet *Parser::parseIntegerSetInline() { +IntegerSet Parser::parseIntegerSetInline() { return AffineParser(state).parseIntegerSetInline(); } @@ -2571,14 +2571,14 @@ IntegerSet *Parser::parseIntegerSetInline() { /// integer-set ::= integer-set-id | integer-set-inline /// integer-set-id ::= `@@` suffix-id /// -IntegerSet *Parser::parseIntegerSetReference() { +IntegerSet Parser::parseIntegerSetReference() { // TODO: change '@@' integer set prefix to '#'. if (getToken().is(Token::double_at_identifier)) { // Parse integer set identifier and verify that it exists. StringRef integerSetId = getTokenSpelling().drop_front(2); if (getState().integerSetDefinitions.count(integerSetId) == 0) return (emitError("undefined integer set id '" + integerSetId + "'"), - nullptr); + IntegerSet()); consumeToken(Token::double_at_identifier); return getState().integerSetDefinitions[integerSetId]; } @@ -2597,12 +2597,12 @@ ParseResult MLFunctionParser::parseIfStmt() { auto loc = getToken().getLoc(); consumeToken(Token::kw_if); - IntegerSet *set = parseIntegerSetReference(); + IntegerSet set = parseIntegerSetReference(); if (!set) return ParseFailure; SmallVector operands; - if (parseDimAndSymbolList(operands, set->getNumDims(), set->getNumOperands(), + if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), "integer set")) return ParseFailure; @@ -2757,7 +2757,7 @@ ParseResult ModuleParser::parseIntegerSetDef() { StringRef integerSetId = getTokenSpelling().drop_front(2); // Check for redefinitions (a default entry is created if one doesn't exist) - auto *&entry = getState().integerSetDefinitions[integerSetId]; + auto &entry = getState().integerSetDefinitions[integerSetId]; if (entry) return emitError("redefinition of integer set id '" + integerSetId + "'");