[mlir][dataflow] Remove Lattice::isUninitialized().

Currently, for sparse analyses, we always store a `Optional<ValueT>` in each lattice element. When it's `None`, we consider the lattice element as `uninitialized`.

However:

* Not all lattices have an `uninitialized` state. For example, `Executable` and `PredecessorState` have default values so they are always initialized.

* In dense analyses, we don't have the concept of an `uninitialized` state.

Given these inconsistencies, this patch removes `Lattice::isUninitialized()`. Individual analysis states are now default-constructed. If the default state of an analysis can be considered as "uninitialized" then this analysis should implement the following logic:

* Special join rule: `join(uninitialized, any) == any`.

* Special bail out logic: if any of the input states is uninitialized, exit the transfer function early.

Depends On D132086

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D132800
This commit is contained in:
Zhixun Tan 2022-09-08 08:43:47 -07:00 committed by Jeff Niu
parent afa0ed33df
commit 47bf3e3812
14 changed files with 107 additions and 85 deletions

View File

@ -28,15 +28,24 @@ namespace dataflow {
/// This lattice value represents a known constant value of a lattice.
class ConstantValue {
public:
/// Construct a constant value as uninitialized.
explicit ConstantValue() = default;
/// Construct a constant value with a known constant.
ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr)
: constant(knownValue), dialect(dialect) {}
explicit ConstantValue(Attribute constant, Dialect *dialect)
: constant(constant), dialect(dialect) {}
/// Get the constant value. Returns null if no value was determined.
Attribute getConstantValue() const { return constant; }
Attribute getConstantValue() const {
assert(!isUninitialized());
return *constant;
}
/// Get the dialect instance that can be used to materialize the constant.
Dialect *getConstantDialect() const { return dialect; }
Dialect *getConstantDialect() const {
assert(!isUninitialized());
return dialect;
}
/// Compare the constant values.
bool operator==(const ConstantValue &rhs) const {
@ -46,21 +55,36 @@ public:
/// Print the constant value.
void print(raw_ostream &os) const;
/// The state where the constant value is uninitialized. This happens when the
/// state hasn't been set during the analysis.
static ConstantValue getUninitialized() { return ConstantValue{}; }
/// Whether the state is uninitialized.
bool isUninitialized() const { return !constant.has_value(); }
/// The state where the constant value is unknown.
static ConstantValue getUnknownConstant() { return {}; }
static ConstantValue getUnknownConstant() {
return ConstantValue{/*constant=*/nullptr, /*dialect=*/nullptr};
}
/// The union with another constant value is null if they are different, and
/// the same if they are the same.
static ConstantValue join(const ConstantValue &lhs,
const ConstantValue &rhs) {
return lhs == rhs ? lhs : ConstantValue();
if (lhs.isUninitialized())
return rhs;
if (rhs.isUninitialized())
return lhs;
if (lhs == rhs)
return lhs;
return getUnknownConstant();
}
private:
/// The constant value.
Attribute constant;
/// An dialect instance that can be used to materialize the constant.
Dialect *dialect;
Optional<Attribute> constant;
/// A dialect instance that can be used to materialize the constant.
Dialect *dialect = nullptr;
};
//===----------------------------------------------------------------------===//

View File

@ -38,9 +38,6 @@ class Executable : public AnalysisState {
public:
using AnalysisState::AnalysisState;
/// The state is initialized by default.
bool isUninitialized() const override { return false; }
/// Set the state of the program point to live.
ChangeResult setToLive();
@ -95,9 +92,6 @@ class PredecessorState : public AnalysisState {
public:
using AnalysisState::AnalysisState;
/// The state is initialized by default.
bool isUninitialized() const override { return false; }
/// Print the known predecessors.
void print(raw_ostream &os) const override;

View File

@ -30,10 +30,18 @@ public:
static IntegerValueRange getMaxRange(Value value);
/// Create an integer value range lattice value.
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
IntegerValueRange(Optional<ConstantIntRanges> value = None)
: value(std::move(value)) {}
/// Whether the range is uninitialized. This happens when the state hasn't
/// been set during the analysis.
bool isUninitialized() const { return !value.has_value(); }
/// Get the known integer value range.
const ConstantIntRanges &getValue() const { return value; }
const ConstantIntRanges &getValue() const {
assert(!isUninitialized());
return *value;
}
/// Compare two ranges.
bool operator==(const IntegerValueRange &rhs) const {
@ -43,7 +51,11 @@ public:
/// Take the union of two ranges.
static IntegerValueRange join(const IntegerValueRange &lhs,
const IntegerValueRange &rhs) {
return lhs.value.rangeUnion(rhs.value);
if (lhs.isUninitialized())
return rhs;
if (rhs.isUninitialized())
return lhs;
return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
}
/// Print the integer value range.
@ -51,7 +63,7 @@ public:
private:
/// The known integer value range.
ConstantIntRanges value;
Optional<ConstantIntRanges> value;
};
/// This lattice element represents the integer value range of an SSA value.

View File

@ -81,27 +81,17 @@ public:
/// Return the value held by this lattice. This requires that the value is
/// initialized.
ValueT &getValue() {
assert(!isUninitialized() && "expected known lattice element");
return *value;
}
ValueT &getValue() { return value; }
const ValueT &getValue() const {
return const_cast<Lattice<ValueT> *>(this)->getValue();
}
/// Returns true if the value of this lattice hasn't yet been initialized.
bool isUninitialized() const override { return !value.has_value(); }
/// Join the information contained in the 'rhs' lattice into this
/// lattice. Returns if the state of the current lattice changed.
ChangeResult join(const AbstractSparseLattice &rhs) override {
const Lattice<ValueT> &rhsLattice =
static_cast<const Lattice<ValueT> &>(rhs);
// If rhs is uninitialized, there is nothing to do.
if (rhsLattice.isUninitialized())
return ChangeResult::NoChange;
// Join the rhs value into this lattice.
return join(rhsLattice.getValue());
}
@ -109,15 +99,9 @@ public:
/// Join the information contained in the 'rhs' value into this
/// lattice. Returns if the state of the current lattice changed.
ChangeResult join(const ValueT &rhs) {
// If the current lattice is uninitialized, copy the rhs value.
if (isUninitialized()) {
value = rhs;
return ChangeResult::Change;
}
// Otherwise, join rhs with the current optimistic value.
ValueT newValue = ValueT::join(*value, rhs);
assert(ValueT::join(newValue, *value) == newValue &&
ValueT newValue = ValueT::join(value, rhs);
assert(ValueT::join(newValue, value) == newValue &&
"expected `join` to be monotonic");
assert(ValueT::join(newValue, rhs) == newValue &&
"expected `join` to be monotonic");
@ -131,17 +115,11 @@ public:
}
/// Print the lattice element.
void print(raw_ostream &os) const override {
if (value)
value->print(os);
else
os << "<NULL>";
}
void print(raw_ostream &os) const override { value.print(os); }
private:
/// The currently computed value that is optimistically assumed to be true,
/// or None if the lattice element is uninitialized.
Optional<ValueT> value;
/// The currently computed value that is optimistically assumed to be true.
ValueT value;
};
//===----------------------------------------------------------------------===//

View File

@ -291,9 +291,6 @@ public:
/// Returns the program point this static is located at.
ProgramPoint getPoint() const { return point; }
/// Returns true if the analysis state is uninitialized.
virtual bool isUninitialized() const = 0;
/// Print the contents of the analysis state.
virtual void print(raw_ostream &os) const = 0;

View File

@ -20,9 +20,15 @@ using namespace mlir::dataflow;
//===----------------------------------------------------------------------===//
void ConstantValue::print(raw_ostream &os) const {
if (constant)
return constant.print(os);
os << "<NO VALUE>";
if (isUninitialized()) {
os << "<UNINITIALIZED>";
return;
}
if (getConstantValue() == nullptr) {
os << "<UNKNOWN>";
return;
}
return getConstantValue().print(os);
}
//===----------------------------------------------------------------------===//
@ -45,8 +51,11 @@ void SparseConstantPropagation::visitOperation(
SmallVector<Attribute, 8> constantOperands;
constantOperands.reserve(op->getNumOperands());
for (auto *operandLattice : operands)
for (auto *operandLattice : operands) {
if (operandLattice->getValue().isUninitialized())
return;
constantOperands.push_back(operandLattice->getValue().getConstantValue());
}
// Save the original operands and attributes just in case the operation
// folds in-place. The constant passed in may not correspond to the real

View File

@ -318,7 +318,7 @@ static Optional<SmallVector<Attribute>> getOperandValuesImpl(
for (Value operand : op->getOperands()) {
const Lattice<ConstantValue> *cv = getLattice(operand);
// If any of the operands' values are uninitialized, bail out.
if (cv->isUninitialized())
if (cv->getValue().isUninitialized())
return {};
operands.push_back(cv->getValue().getConstantValue());
}

View File

@ -74,9 +74,6 @@ void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
before = getLatticeFor(op, prev);
else
before = getLatticeFor(op, op->getBlock());
// If the incoming lattice is uninitialized, bail out.
if (before->isUninitialized())
return;
// Invoke the operation transfer function.
visitOperationImpl(op, *before, after);

View File

@ -29,7 +29,7 @@ IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
APInt umax = APInt::getMaxValue(width);
APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
return {{umin, umax, smin, smax}};
return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
}
void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
@ -57,6 +57,13 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
void IntegerRangeAnalysis::visitOperation(
Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
ArrayRef<IntegerValueRangeLattice *> results) {
// If the lattice on any operand is unitialized, bail out.
if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
return lattice->getValue().isUninitialized();
})) {
return;
}
// Ignore non-integer outputs - return early if the op has no scalar
// integer results
bool hasIntegerResult = false;
@ -91,11 +98,9 @@ void IntegerRangeAnalysis::visitOperation(
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
Optional<IntegerValueRange> oldRange;
if (!lattice->isUninitialized())
oldRange = lattice->getValue();
IntegerValueRange oldRange = lattice->getValue();
ChangeResult changed = lattice->join(attrs);
ChangeResult changed = lattice->join(IntegerValueRange{attrs});
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
@ -104,8 +109,8 @@ void IntegerRangeAnalysis::visitOperation(
bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
return op->hasTrait<OpTrait::IsTerminator>();
});
if (isYieldedResult && oldRange.has_value() &&
!(lattice->getValue() == *oldRange)) {
if (isYieldedResult && !oldRange.isUninitialized() &&
!(lattice->getValue() == oldRange)) {
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
changed |= lattice->join(IntegerValueRange::getMaxRange(v));
}
@ -134,11 +139,9 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
Optional<IntegerValueRange> oldRange;
if (!lattice->isUninitialized())
oldRange = lattice->getValue();
IntegerValueRange oldRange = lattice->getValue();
ChangeResult changed = lattice->join(attrs);
ChangeResult changed = lattice->join(IntegerValueRange{attrs});
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
@ -147,7 +150,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
return op->hasTrait<OpTrait::IsTerminator>();
});
if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
if (isYieldedValue && !oldRange.isUninitialized() &&
!(lattice->getValue() == oldRange)) {
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
changed |= lattice->join(IntegerValueRange::getMaxRange(v));
}
@ -212,7 +216,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
auto ivRange = ConstantIntRanges::fromSigned(min, max);
propagateIfChanged(ivEntry, ivEntry->join(ivRange));
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
return;
}

View File

@ -117,9 +117,6 @@ void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
for (Value operand : op->getOperands()) {
AbstractSparseLattice *operandLattice = getLatticeElement(operand);
operandLattice->useDefSubscribe(this);
// If any of the operand states are not initialized, bail out.
if (operandLattice->isUninitialized())
return;
operandLattices.push_back(operandLattice);
}

View File

@ -43,7 +43,7 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver,
OpBuilder &builder,
OperationFolder &folder, Value value) {
auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
if (!lattice || lattice->isUninitialized())
if (!lattice || lattice->getValue().isUninitialized())
return failure();
const ConstantValue &latticeValue = lattice->getValue();
if (!latticeValue.getConstantValue())

View File

@ -21,16 +21,29 @@ namespace {
class UnderlyingValue {
public:
/// Create an underlying value state with a known underlying value.
UnderlyingValue(Value underlyingValue) : underlyingValue(underlyingValue) {}
explicit UnderlyingValue(Optional<Value> underlyingValue = None)
: underlyingValue(underlyingValue) {}
/// Whether the state is uninitialized.
bool isUninitialized() const { return !underlyingValue.has_value(); }
/// Returns the underlying value.
Value getUnderlyingValue() const { return underlyingValue; }
Value getUnderlyingValue() const {
assert(!isUninitialized());
return *underlyingValue;
}
/// Join two underlying values. If there are conflicting underlying values,
/// go to the pessimistic value.
static UnderlyingValue join(const UnderlyingValue &lhs,
const UnderlyingValue &rhs) {
return lhs.underlyingValue == rhs.underlyingValue ? lhs : Value();
if (lhs.isUninitialized())
return rhs;
if (rhs.isUninitialized())
return lhs;
return lhs.underlyingValue == rhs.underlyingValue
? lhs
: UnderlyingValue(Value{});
}
/// Compare underlying values.
@ -41,7 +54,7 @@ public:
void print(raw_ostream &os) const { os << underlyingValue; }
private:
Value underlyingValue;
Optional<Value> underlyingValue;
};
/// This lattice represents, for a given memory resource, the potential last
@ -52,9 +65,6 @@ public:
using AbstractDenseLattice::AbstractDenseLattice;
/// The lattice is always initialized.
bool isUninitialized() const override { return false; }
/// Clear all modifications.
ChangeResult reset() {
if (lastMods.empty())
@ -169,7 +179,7 @@ static Value getMostUnderlyingValue(
const UnderlyingValueLattice *underlying;
do {
underlying = getUnderlyingValueFn(value);
if (!underlying || underlying->isUninitialized())
if (!underlying || underlying->getValue().isUninitialized())
return {};
Value underlyingValue = underlying->getValue().getUnderlyingValue();
if (underlyingValue == value)

View File

@ -21,7 +21,7 @@ public:
using AnalysisState::AnalysisState;
/// Returns true if the state is uninitialized.
bool isUninitialized() const override { return !state; }
bool isUninitialized() const { return !state; }
/// Print the integer value or "none" if uninitialized.
void print(raw_ostream &os) const override {

View File

@ -25,7 +25,7 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
OperationFolder &folder, Value value) {
auto *maybeInferredRange =
solver.lookupState<IntegerValueRangeLattice>(value);
if (!maybeInferredRange || maybeInferredRange->isUninitialized())
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
return failure();
const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();