forked from OSchip/llvm-project
[mlir][dataflow] Remove Lattice::isUninitialized().
Currently, for sparse analyses, we always store a `Optional<ValueT>` in each lattice element. When it's `None`, we consider the lattice element as `uninitialized`. However: * Not all lattices have an `uninitialized` state. For example, `Executable` and `PredecessorState` have default values so they are always initialized. * In dense analyses, we don't have the concept of an `uninitialized` state. Given these inconsistencies, this patch removes `Lattice::isUninitialized()`. Individual analysis states are now default-constructed. If the default state of an analysis can be considered as "uninitialized" then this analysis should implement the following logic: * Special join rule: `join(uninitialized, any) == any`. * Special bail out logic: if any of the input states is uninitialized, exit the transfer function early. Depends On D132086 Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D132800
This commit is contained in:
parent
afa0ed33df
commit
47bf3e3812
|
@ -28,15 +28,24 @@ namespace dataflow {
|
|||
/// This lattice value represents a known constant value of a lattice.
|
||||
class ConstantValue {
|
||||
public:
|
||||
/// Construct a constant value as uninitialized.
|
||||
explicit ConstantValue() = default;
|
||||
|
||||
/// Construct a constant value with a known constant.
|
||||
ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr)
|
||||
: constant(knownValue), dialect(dialect) {}
|
||||
explicit ConstantValue(Attribute constant, Dialect *dialect)
|
||||
: constant(constant), dialect(dialect) {}
|
||||
|
||||
/// Get the constant value. Returns null if no value was determined.
|
||||
Attribute getConstantValue() const { return constant; }
|
||||
Attribute getConstantValue() const {
|
||||
assert(!isUninitialized());
|
||||
return *constant;
|
||||
}
|
||||
|
||||
/// Get the dialect instance that can be used to materialize the constant.
|
||||
Dialect *getConstantDialect() const { return dialect; }
|
||||
Dialect *getConstantDialect() const {
|
||||
assert(!isUninitialized());
|
||||
return dialect;
|
||||
}
|
||||
|
||||
/// Compare the constant values.
|
||||
bool operator==(const ConstantValue &rhs) const {
|
||||
|
@ -46,21 +55,36 @@ public:
|
|||
/// Print the constant value.
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
/// The state where the constant value is uninitialized. This happens when the
|
||||
/// state hasn't been set during the analysis.
|
||||
static ConstantValue getUninitialized() { return ConstantValue{}; }
|
||||
|
||||
/// Whether the state is uninitialized.
|
||||
bool isUninitialized() const { return !constant.has_value(); }
|
||||
|
||||
/// The state where the constant value is unknown.
|
||||
static ConstantValue getUnknownConstant() { return {}; }
|
||||
static ConstantValue getUnknownConstant() {
|
||||
return ConstantValue{/*constant=*/nullptr, /*dialect=*/nullptr};
|
||||
}
|
||||
|
||||
/// The union with another constant value is null if they are different, and
|
||||
/// the same if they are the same.
|
||||
static ConstantValue join(const ConstantValue &lhs,
|
||||
const ConstantValue &rhs) {
|
||||
return lhs == rhs ? lhs : ConstantValue();
|
||||
if (lhs.isUninitialized())
|
||||
return rhs;
|
||||
if (rhs.isUninitialized())
|
||||
return lhs;
|
||||
if (lhs == rhs)
|
||||
return lhs;
|
||||
return getUnknownConstant();
|
||||
}
|
||||
|
||||
private:
|
||||
/// The constant value.
|
||||
Attribute constant;
|
||||
/// An dialect instance that can be used to materialize the constant.
|
||||
Dialect *dialect;
|
||||
Optional<Attribute> constant;
|
||||
/// A dialect instance that can be used to materialize the constant.
|
||||
Dialect *dialect = nullptr;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -38,9 +38,6 @@ class Executable : public AnalysisState {
|
|||
public:
|
||||
using AnalysisState::AnalysisState;
|
||||
|
||||
/// The state is initialized by default.
|
||||
bool isUninitialized() const override { return false; }
|
||||
|
||||
/// Set the state of the program point to live.
|
||||
ChangeResult setToLive();
|
||||
|
||||
|
@ -95,9 +92,6 @@ class PredecessorState : public AnalysisState {
|
|||
public:
|
||||
using AnalysisState::AnalysisState;
|
||||
|
||||
/// The state is initialized by default.
|
||||
bool isUninitialized() const override { return false; }
|
||||
|
||||
/// Print the known predecessors.
|
||||
void print(raw_ostream &os) const override;
|
||||
|
||||
|
|
|
@ -30,10 +30,18 @@ public:
|
|||
static IntegerValueRange getMaxRange(Value value);
|
||||
|
||||
/// Create an integer value range lattice value.
|
||||
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
|
||||
IntegerValueRange(Optional<ConstantIntRanges> value = None)
|
||||
: value(std::move(value)) {}
|
||||
|
||||
/// Whether the range is uninitialized. This happens when the state hasn't
|
||||
/// been set during the analysis.
|
||||
bool isUninitialized() const { return !value.has_value(); }
|
||||
|
||||
/// Get the known integer value range.
|
||||
const ConstantIntRanges &getValue() const { return value; }
|
||||
const ConstantIntRanges &getValue() const {
|
||||
assert(!isUninitialized());
|
||||
return *value;
|
||||
}
|
||||
|
||||
/// Compare two ranges.
|
||||
bool operator==(const IntegerValueRange &rhs) const {
|
||||
|
@ -43,7 +51,11 @@ public:
|
|||
/// Take the union of two ranges.
|
||||
static IntegerValueRange join(const IntegerValueRange &lhs,
|
||||
const IntegerValueRange &rhs) {
|
||||
return lhs.value.rangeUnion(rhs.value);
|
||||
if (lhs.isUninitialized())
|
||||
return rhs;
|
||||
if (rhs.isUninitialized())
|
||||
return lhs;
|
||||
return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())};
|
||||
}
|
||||
|
||||
/// Print the integer value range.
|
||||
|
@ -51,7 +63,7 @@ public:
|
|||
|
||||
private:
|
||||
/// The known integer value range.
|
||||
ConstantIntRanges value;
|
||||
Optional<ConstantIntRanges> value;
|
||||
};
|
||||
|
||||
/// This lattice element represents the integer value range of an SSA value.
|
||||
|
|
|
@ -81,27 +81,17 @@ public:
|
|||
|
||||
/// Return the value held by this lattice. This requires that the value is
|
||||
/// initialized.
|
||||
ValueT &getValue() {
|
||||
assert(!isUninitialized() && "expected known lattice element");
|
||||
return *value;
|
||||
}
|
||||
ValueT &getValue() { return value; }
|
||||
const ValueT &getValue() const {
|
||||
return const_cast<Lattice<ValueT> *>(this)->getValue();
|
||||
}
|
||||
|
||||
/// Returns true if the value of this lattice hasn't yet been initialized.
|
||||
bool isUninitialized() const override { return !value.has_value(); }
|
||||
|
||||
/// Join the information contained in the 'rhs' lattice into this
|
||||
/// lattice. Returns if the state of the current lattice changed.
|
||||
ChangeResult join(const AbstractSparseLattice &rhs) override {
|
||||
const Lattice<ValueT> &rhsLattice =
|
||||
static_cast<const Lattice<ValueT> &>(rhs);
|
||||
|
||||
// If rhs is uninitialized, there is nothing to do.
|
||||
if (rhsLattice.isUninitialized())
|
||||
return ChangeResult::NoChange;
|
||||
|
||||
// Join the rhs value into this lattice.
|
||||
return join(rhsLattice.getValue());
|
||||
}
|
||||
|
@ -109,15 +99,9 @@ public:
|
|||
/// Join the information contained in the 'rhs' value into this
|
||||
/// lattice. Returns if the state of the current lattice changed.
|
||||
ChangeResult join(const ValueT &rhs) {
|
||||
// If the current lattice is uninitialized, copy the rhs value.
|
||||
if (isUninitialized()) {
|
||||
value = rhs;
|
||||
return ChangeResult::Change;
|
||||
}
|
||||
|
||||
// Otherwise, join rhs with the current optimistic value.
|
||||
ValueT newValue = ValueT::join(*value, rhs);
|
||||
assert(ValueT::join(newValue, *value) == newValue &&
|
||||
ValueT newValue = ValueT::join(value, rhs);
|
||||
assert(ValueT::join(newValue, value) == newValue &&
|
||||
"expected `join` to be monotonic");
|
||||
assert(ValueT::join(newValue, rhs) == newValue &&
|
||||
"expected `join` to be monotonic");
|
||||
|
@ -131,17 +115,11 @@ public:
|
|||
}
|
||||
|
||||
/// Print the lattice element.
|
||||
void print(raw_ostream &os) const override {
|
||||
if (value)
|
||||
value->print(os);
|
||||
else
|
||||
os << "<NULL>";
|
||||
}
|
||||
void print(raw_ostream &os) const override { value.print(os); }
|
||||
|
||||
private:
|
||||
/// The currently computed value that is optimistically assumed to be true,
|
||||
/// or None if the lattice element is uninitialized.
|
||||
Optional<ValueT> value;
|
||||
/// The currently computed value that is optimistically assumed to be true.
|
||||
ValueT value;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -291,9 +291,6 @@ public:
|
|||
/// Returns the program point this static is located at.
|
||||
ProgramPoint getPoint() const { return point; }
|
||||
|
||||
/// Returns true if the analysis state is uninitialized.
|
||||
virtual bool isUninitialized() const = 0;
|
||||
|
||||
/// Print the contents of the analysis state.
|
||||
virtual void print(raw_ostream &os) const = 0;
|
||||
|
||||
|
|
|
@ -20,9 +20,15 @@ using namespace mlir::dataflow;
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ConstantValue::print(raw_ostream &os) const {
|
||||
if (constant)
|
||||
return constant.print(os);
|
||||
os << "<NO VALUE>";
|
||||
if (isUninitialized()) {
|
||||
os << "<UNINITIALIZED>";
|
||||
return;
|
||||
}
|
||||
if (getConstantValue() == nullptr) {
|
||||
os << "<UNKNOWN>";
|
||||
return;
|
||||
}
|
||||
return getConstantValue().print(os);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -45,8 +51,11 @@ void SparseConstantPropagation::visitOperation(
|
|||
|
||||
SmallVector<Attribute, 8> constantOperands;
|
||||
constantOperands.reserve(op->getNumOperands());
|
||||
for (auto *operandLattice : operands)
|
||||
for (auto *operandLattice : operands) {
|
||||
if (operandLattice->getValue().isUninitialized())
|
||||
return;
|
||||
constantOperands.push_back(operandLattice->getValue().getConstantValue());
|
||||
}
|
||||
|
||||
// Save the original operands and attributes just in case the operation
|
||||
// folds in-place. The constant passed in may not correspond to the real
|
||||
|
|
|
@ -318,7 +318,7 @@ static Optional<SmallVector<Attribute>> getOperandValuesImpl(
|
|||
for (Value operand : op->getOperands()) {
|
||||
const Lattice<ConstantValue> *cv = getLattice(operand);
|
||||
// If any of the operands' values are uninitialized, bail out.
|
||||
if (cv->isUninitialized())
|
||||
if (cv->getValue().isUninitialized())
|
||||
return {};
|
||||
operands.push_back(cv->getValue().getConstantValue());
|
||||
}
|
||||
|
|
|
@ -74,9 +74,6 @@ void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
|
|||
before = getLatticeFor(op, prev);
|
||||
else
|
||||
before = getLatticeFor(op, op->getBlock());
|
||||
// If the incoming lattice is uninitialized, bail out.
|
||||
if (before->isUninitialized())
|
||||
return;
|
||||
|
||||
// Invoke the operation transfer function.
|
||||
visitOperationImpl(op, *before, after);
|
||||
|
|
|
@ -29,7 +29,7 @@ IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
|
|||
APInt umax = APInt::getMaxValue(width);
|
||||
APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
|
||||
APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
|
||||
return {{umin, umax, smin, smax}};
|
||||
return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
|
||||
}
|
||||
|
||||
void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
|
||||
|
@ -57,6 +57,13 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
|
|||
void IntegerRangeAnalysis::visitOperation(
|
||||
Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
|
||||
ArrayRef<IntegerValueRangeLattice *> results) {
|
||||
// If the lattice on any operand is unitialized, bail out.
|
||||
if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) {
|
||||
return lattice->getValue().isUninitialized();
|
||||
})) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore non-integer outputs - return early if the op has no scalar
|
||||
// integer results
|
||||
bool hasIntegerResult = false;
|
||||
|
@ -91,11 +98,9 @@ void IntegerRangeAnalysis::visitOperation(
|
|||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
|
||||
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
|
||||
Optional<IntegerValueRange> oldRange;
|
||||
if (!lattice->isUninitialized())
|
||||
oldRange = lattice->getValue();
|
||||
IntegerValueRange oldRange = lattice->getValue();
|
||||
|
||||
ChangeResult changed = lattice->join(attrs);
|
||||
ChangeResult changed = lattice->join(IntegerValueRange{attrs});
|
||||
|
||||
// Catch loop results with loop variant bounds and conservatively make
|
||||
// them [-inf, inf] so we don't circle around infinitely often (because
|
||||
|
@ -104,8 +109,8 @@ void IntegerRangeAnalysis::visitOperation(
|
|||
bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
|
||||
return op->hasTrait<OpTrait::IsTerminator>();
|
||||
});
|
||||
if (isYieldedResult && oldRange.has_value() &&
|
||||
!(lattice->getValue() == *oldRange)) {
|
||||
if (isYieldedResult && !oldRange.isUninitialized() &&
|
||||
!(lattice->getValue() == oldRange)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
|
||||
changed |= lattice->join(IntegerValueRange::getMaxRange(v));
|
||||
}
|
||||
|
@ -134,11 +139,9 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
|
|||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
|
||||
IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
|
||||
Optional<IntegerValueRange> oldRange;
|
||||
if (!lattice->isUninitialized())
|
||||
oldRange = lattice->getValue();
|
||||
IntegerValueRange oldRange = lattice->getValue();
|
||||
|
||||
ChangeResult changed = lattice->join(attrs);
|
||||
ChangeResult changed = lattice->join(IntegerValueRange{attrs});
|
||||
|
||||
// Catch loop results with loop variant bounds and conservatively make
|
||||
// them [-inf, inf] so we don't circle around infinitely often (because
|
||||
|
@ -147,7 +150,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
|
|||
bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
|
||||
return op->hasTrait<OpTrait::IsTerminator>();
|
||||
});
|
||||
if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
|
||||
if (isYieldedValue && !oldRange.isUninitialized() &&
|
||||
!(lattice->getValue() == oldRange)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
|
||||
changed |= lattice->join(IntegerValueRange::getMaxRange(v));
|
||||
}
|
||||
|
@ -212,7 +216,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
|
|||
|
||||
IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
|
||||
auto ivRange = ConstantIntRanges::fromSigned(min, max);
|
||||
propagateIfChanged(ivEntry, ivEntry->join(ivRange));
|
||||
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -117,9 +117,6 @@ void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
|
|||
for (Value operand : op->getOperands()) {
|
||||
AbstractSparseLattice *operandLattice = getLatticeElement(operand);
|
||||
operandLattice->useDefSubscribe(this);
|
||||
// If any of the operand states are not initialized, bail out.
|
||||
if (operandLattice->isUninitialized())
|
||||
return;
|
||||
operandLattices.push_back(operandLattice);
|
||||
}
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver,
|
|||
OpBuilder &builder,
|
||||
OperationFolder &folder, Value value) {
|
||||
auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
|
||||
if (!lattice || lattice->isUninitialized())
|
||||
if (!lattice || lattice->getValue().isUninitialized())
|
||||
return failure();
|
||||
const ConstantValue &latticeValue = lattice->getValue();
|
||||
if (!latticeValue.getConstantValue())
|
||||
|
|
|
@ -21,16 +21,29 @@ namespace {
|
|||
class UnderlyingValue {
|
||||
public:
|
||||
/// Create an underlying value state with a known underlying value.
|
||||
UnderlyingValue(Value underlyingValue) : underlyingValue(underlyingValue) {}
|
||||
explicit UnderlyingValue(Optional<Value> underlyingValue = None)
|
||||
: underlyingValue(underlyingValue) {}
|
||||
|
||||
/// Whether the state is uninitialized.
|
||||
bool isUninitialized() const { return !underlyingValue.has_value(); }
|
||||
|
||||
/// Returns the underlying value.
|
||||
Value getUnderlyingValue() const { return underlyingValue; }
|
||||
Value getUnderlyingValue() const {
|
||||
assert(!isUninitialized());
|
||||
return *underlyingValue;
|
||||
}
|
||||
|
||||
/// Join two underlying values. If there are conflicting underlying values,
|
||||
/// go to the pessimistic value.
|
||||
static UnderlyingValue join(const UnderlyingValue &lhs,
|
||||
const UnderlyingValue &rhs) {
|
||||
return lhs.underlyingValue == rhs.underlyingValue ? lhs : Value();
|
||||
if (lhs.isUninitialized())
|
||||
return rhs;
|
||||
if (rhs.isUninitialized())
|
||||
return lhs;
|
||||
return lhs.underlyingValue == rhs.underlyingValue
|
||||
? lhs
|
||||
: UnderlyingValue(Value{});
|
||||
}
|
||||
|
||||
/// Compare underlying values.
|
||||
|
@ -41,7 +54,7 @@ public:
|
|||
void print(raw_ostream &os) const { os << underlyingValue; }
|
||||
|
||||
private:
|
||||
Value underlyingValue;
|
||||
Optional<Value> underlyingValue;
|
||||
};
|
||||
|
||||
/// This lattice represents, for a given memory resource, the potential last
|
||||
|
@ -52,9 +65,6 @@ public:
|
|||
|
||||
using AbstractDenseLattice::AbstractDenseLattice;
|
||||
|
||||
/// The lattice is always initialized.
|
||||
bool isUninitialized() const override { return false; }
|
||||
|
||||
/// Clear all modifications.
|
||||
ChangeResult reset() {
|
||||
if (lastMods.empty())
|
||||
|
@ -169,7 +179,7 @@ static Value getMostUnderlyingValue(
|
|||
const UnderlyingValueLattice *underlying;
|
||||
do {
|
||||
underlying = getUnderlyingValueFn(value);
|
||||
if (!underlying || underlying->isUninitialized())
|
||||
if (!underlying || underlying->getValue().isUninitialized())
|
||||
return {};
|
||||
Value underlyingValue = underlying->getValue().getUnderlyingValue();
|
||||
if (underlyingValue == value)
|
||||
|
|
|
@ -21,7 +21,7 @@ public:
|
|||
using AnalysisState::AnalysisState;
|
||||
|
||||
/// Returns true if the state is uninitialized.
|
||||
bool isUninitialized() const override { return !state; }
|
||||
bool isUninitialized() const { return !state; }
|
||||
|
||||
/// Print the integer value or "none" if uninitialized.
|
||||
void print(raw_ostream &os) const override {
|
||||
|
|
|
@ -25,7 +25,7 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
|
|||
OperationFolder &folder, Value value) {
|
||||
auto *maybeInferredRange =
|
||||
solver.lookupState<IntegerValueRangeLattice>(value);
|
||||
if (!maybeInferredRange || maybeInferredRange->isUninitialized())
|
||||
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
|
||||
return failure();
|
||||
const ConstantIntRanges &inferredRange =
|
||||
maybeInferredRange->getValue().getValue();
|
||||
|
|
Loading…
Reference in New Issue