[mlir] Add integer range inference analysis

This commit defines a dataflow analysis for integer ranges, which
uses a newly-added InferIntRangeInterface to compute the lower and
upper bounds on the results of an operation from the bounds on the
arguments. The range inference is a flow-insensitive dataflow analysis
that can be used to simplify code, such as by statically identifying
bounds checks that cannot fail in order to eliminate them.

The InferIntRangeInterface has one method, inferResultRanges(), which
takes a vector of inferred ranges for each argument to an op
implementing the interface and a callback allowing the implementation
to define the ranges for each result. These ranges are stored as
ConstantIntRanges, which hold the lower and upper bounds for a
value. Bounds are tracked separately for the signed and unsigned
interpretations of a value, which ensures that the impact of
arithmetic overflows is correctly tracked during the analysis.

The commit also adds a -test-int-range-inference pass to test the
analysis until it is integrated into SCCP or otherwise exposed.

Finally, this commit fixes some bugs relating to the handling of
region iteration arguments and terminators in the data flow analysis
framework.

Depends on D124020

Depends on D124021

Reviewed By: rriddle, Mogball

Differential Revision: https://reviews.llvm.org/D124023
This commit is contained in:
Krzysztof Drewniak 2022-06-02 19:04:42 +00:00
parent 8601f269f1
commit 1350c9887d
19 changed files with 1110 additions and 8 deletions

View File

@ -0,0 +1,41 @@
//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the dataflow analysis class for integer range inference
// so that it can be used in transformations over the `arith` dialect such as
// branch elimination or signed->unsigned rewriting
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
#define MLIR_ANALYSIS_INTRANGEANALYSIS_H
#include "mlir/Interfaces/InferIntRangeInterface.h"
namespace mlir {
namespace detail {
class IntRangeAnalysisImpl;
} // end namespace detail
class IntRangeAnalysis {
public:
/// Analyze all operations rooted under (but not including)
/// `topLevelOperation`.
IntRangeAnalysis(Operation *topLevelOperation);
IntRangeAnalysis(IntRangeAnalysis &&other);
~IntRangeAnalysis();
/// Get inferred range for value `v` if one exists.
Optional<ConstantIntRanges> getResult(Value v);
private:
std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
};
} // end namespace mlir
#endif

View File

@ -3,6 +3,7 @@ add_mlir_interface(CastInterfaces)
add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(CopyOpInterface)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(SideEffectInterfaces)

View File

@ -0,0 +1,98 @@
//===- InferIntRangeInterface.h - Integer Range Inference --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of the integer range inference interface
// defined in `InferIntRange.td`
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
#include "mlir/IR/OpDefinition.h"
namespace mlir {
/// A set of arbitrary-precision integers representing bounds on a given integer
/// value. These bounds are inclusive on both ends, so
/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
/// the unsigned and signed interpretations of values in order to enable more
/// precice inference of the interplay between operations with signed and
/// unsigned semantics.
class ConstantIntRanges {
public:
/// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
/// Non-integer values should be bounded by APInts of bitwidth 0.
ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
const APInt &smax)
: uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
assert(uminVal.getBitWidth() == umaxVal.getBitWidth() &&
umaxVal.getBitWidth() == sminVal.getBitWidth() &&
sminVal.getBitWidth() == smaxVal.getBitWidth() &&
"All bounds in the ranges must have the same bitwidth");
}
bool operator==(const ConstantIntRanges &other) const;
/// The minimum value of an integer when it is interpreted as unsigned.
const APInt &umin() const;
/// The maximum value of an integer when it is interpreted as unsigned.
const APInt &umax() const;
/// The minimum value of an integer when it is interpreted as signed.
const APInt &smin() const;
/// The maximum value of an integer when it is interpreted as signed.
const APInt &smax() const;
/// Return the bitwidth that should be used for integer ranges describing
/// `type`. For concrete integer types, this is their bitwidth, for `index`,
/// this is the internal storage bitwidth of `index` attributes, and for
/// non-integer types this is 0.
static unsigned getStorageBitwidth(Type type);
/// Create an `IntRangeAttrs` where `min` is both the signed and unsigned
/// minimum and `max` is both the signed and unsigned maximum.
static ConstantIntRanges range(const APInt &min, const APInt &max);
/// Create an `IntRangeAttrs` with the signed minimum and maximum equal
/// to `smin` and `smax`, where the unsigned bounds are constructed from the
/// signed ones if they correspond to a contigious range of bit patterns when
/// viewed as unsigned values and are left at [0, int_max()] otherwise.
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax);
/// Create an `IntRangeAttrs` with the unsigned minimum and maximum equal
/// to `umin` and `umax` and the signed part equal to `umin` and `umax`
/// unless the sign bit changes between the minimum and maximum.
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax);
/// Returns the union (computed separately for signed and unsigned bounds)
/// of `a` and `b`.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const;
/// If either the signed or unsigned interpretations of the range
/// indicate that the value it bounds is a constant, return that constant
/// value.
Optional<APInt> getConstantValue() const;
friend raw_ostream &operator<<(raw_ostream &os,
const ConstantIntRanges &range);
private:
APInt uminVal, umaxVal, sminVal, smaxVal;
};
/// The type of the `setResultRanges` callback provided to ops implementing
/// InferIntRangeInterface. It should be called once for each integer result
/// value and be passed the ConstantIntRanges corresponding to that value.
using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
} // end namespace mlir
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H

View File

@ -0,0 +1,52 @@
//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===-----------------------------------------------------===//
//
// Defines the interface for range analysis on scalar integers
//
//===-----------------------------------------------------===//
#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE
#define MLIR_INTERFACES_INFERINTRANGEINTERFACE
include "mlir/IR/OpBase.td"
def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
let description = [{
Allows operations to participate in range analysis for scalar integer values by
providing a methods that allows them to specify lower and upper bounds on their
result(s) given lower and upper bounds on their input(s) if known.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
Infer the bounds on the results of this op given the bounds on its arguments.
For each result value or block argument (that isn't a branch argument,
since the dataflow analysis handles those case), the method should call
`setValueRange` with that `Value` as an argument. When `setValueRange`
is not called for some value, it will recieve a default value of the mimimum
and maximum values forits type (the unbounded range).
When called on an op that also implements the RegionBranchOpInterface
or BranchOpInterface, this method should not attempt to infer the values
of the branch results, as this will be handled by the analyses that use
this interface.
This function will only be called when at least one result of the op is a
scalar integer value or the op has a region.
`argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
order. Non-integer arguments will have the an unbounded range of width-0
APInts in their `argRanges` element.
}],
"void", "inferResultRanges", (ins
"::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
"::mlir::SetIntRangeFn":$setResultRanges)
>];
}
#endif // MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE

View File

@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
CallGraph.cpp
DataFlowAnalysis.cpp
DataLayoutAnalysis.cpp
IntRangeAnalysis.cpp
Liveness.cpp
SliceAnalysis.cpp
@ -16,6 +17,7 @@ add_mlir_library(MLIRAnalysis
CallGraph.cpp
DataFlowAnalysis.cpp
DataLayoutAnalysis.cpp
IntRangeAnalysis.cpp
Liveness.cpp
SliceAnalysis.cpp
@ -31,6 +33,7 @@ add_mlir_library(MLIRAnalysis
MLIRCallInterfaces
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces
MLIRViewLikeInterface

View File

@ -359,11 +359,20 @@ void ForwardDataFlowSolver::visitOperation(Operation *op) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
return visitRegionBranchOperation(branch, operandLattices);
// If we can't, conservatively mark all regions as executable.
// TODO: Let the `visitOperation` method decide how to propagate
// information to the block arguments.
for (Region &region : op->getRegions())
markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
for (Region &region : op->getRegions()) {
analysis.visitNonControlFlowArguments(op, RegionSuccessor(&region),
operandLattices);
// `visitNonControlFlowArguments` is required to define all of the region
// argument lattices.
assert(llvm::none_of(
region.getArguments(),
[&](Value value) {
return analysis.getLatticeElement(value).isUninitialized();
}) &&
"expected `visitNonControlFlowArguments` to define all argument "
"lattices");
markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/false);
}
}
// If this op produces no results, it can't produce any constants.
@ -567,12 +576,45 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
return;
// If the branch is a RegionBranchTerminatorOpInterface,
// construct the set of operand lattices as the set of non control-flow
// arguments of the parent and the values this op returns. This allows
// for the correct lattices to be passed to getSuccessorsForOperands()
// in cases such as scf.while.
ArrayRef<AbstractLatticeElement *> branchOpLattices = operandLattices;
SmallVector<AbstractLatticeElement *, 0> parentLattices;
if (auto regionTerminator =
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
parentLattices.reserve(regionInterface->getNumOperands());
for (Value parentOperand : regionInterface->getOperands()) {
AbstractLatticeElement *operandLattice =
analysis.lookupLatticeElement(parentOperand);
if (!operandLattice || operandLattice->isUninitialized())
return;
parentLattices.push_back(operandLattice);
}
unsigned regionNumber = parentRegion->getRegionNumber();
OperandRange iterArgs =
regionInterface.getSuccessorEntryOperands(regionNumber);
OperandRange terminatorArgs =
regionTerminator.getSuccessorOperands(regionNumber);
assert(iterArgs.size() == terminatorArgs.size() &&
"Number of iteration arguments for region should equal number of "
"those arguments defined by terminator");
if (!iterArgs.empty()) {
unsigned iterStart = iterArgs.getBeginOperandIndex();
unsigned terminatorStart = terminatorArgs.getBeginOperandIndex();
for (unsigned i = 0, e = iterArgs.size(); i < e; ++i)
parentLattices[iterStart + i] = operandLattices[terminatorStart + i];
}
branchOpLattices = parentLattices;
}
// Query the set of successors of the current region using the current
// optimistic lattice state.
SmallVector<RegionSuccessor, 1> regionSuccessors;
analysis.getSuccessorsForOperands(regionInterface,
parentRegion->getRegionNumber(),
operandLattices, regionSuccessors);
branchOpLattices, regionSuccessors);
if (regionSuccessors.empty())
return;
@ -584,7 +626,7 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
// region index (if any).
return *getRegionBranchSuccessorOperands(op, regionIndex);
};
return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
return visitRegionSuccessors(parentOp, regionSuccessors, branchOpLattices,
getOperands);
}

View File

@ -0,0 +1,325 @@
//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the dataflow analysis class for integer range inference
// which is used in transformations over the `arith` dialect such as
// branch elimination or signed->unsigned rewriting
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/IntRangeAnalysis.h"
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "int-range-analysis"
using namespace mlir;
namespace {
/// A wrapper around ConstantIntRanges that provides the lattice functions
/// expected by dataflow analysis.
struct IntRangeLattice {
IntRangeLattice(const ConstantIntRanges &value) : value(value){};
IntRangeLattice(ConstantIntRanges &&value) : value(value){};
bool operator==(const IntRangeLattice &other) const {
return value == other.value;
}
/// wrapper around rangeUnion()
static IntRangeLattice join(const IntRangeLattice &a,
const IntRangeLattice &b) {
return a.value.rangeUnion(b.value);
}
/// Creates a range with bitwidth 0 to represent that we don't know if the
/// value being marked overdefined is even an integer.
static IntRangeLattice getPessimisticValueState(MLIRContext *context) {
APInt noIntValue = APInt::getZeroWidth();
return ConstantIntRanges::range(noIntValue, noIntValue);
}
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
/// range that is used to mark the value v as unable to be analyzed further,
/// where t is the type of v.
static IntRangeLattice getPessimisticValueState(Value v) {
unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType());
APInt umin = APInt::getMinValue(width);
APInt umax = APInt::getMaxValue(width);
APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
return ConstantIntRanges{umin, umax, smin, smax};
}
ConstantIntRanges value;
};
} // end anonymous namespace
namespace mlir {
namespace detail {
class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<IntRangeLattice> {
using ForwardDataFlowAnalysis<IntRangeLattice>::ForwardDataFlowAnalysis;
public:
/// Define bounds on the results or block arguments of the operation
/// based on the bounds on the arguments given in `operands`
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
/// Skip regions of branch ops when we can statically infer constant
/// values for operands to the branch op and said op tells us it's safe to do
/// so.
LogicalResult
getSuccessorsForOperands(BranchOpInterface branch,
ArrayRef<LatticeElement<IntRangeLattice> *> operands,
SmallVectorImpl<Block *> &successors) final;
/// Skip regions of branch or loop ops when we can statically infer constant
/// values for operands to the branch op and said op tells us it's safe to do
/// so.
void
getSuccessorsForOperands(RegionBranchOpInterface branch,
Optional<unsigned> sourceIndex,
ArrayRef<LatticeElement<IntRangeLattice> *> operands,
SmallVectorImpl<RegionSuccessor> &successors) final;
/// Call the InferIntRangeInterface implementation for region-using ops
/// that implement it, and infer the bounds of loop induction variables
/// for ops that implement LoopLikeOPInterface.
ChangeResult visitNonControlFlowArguments(
Operation *op, const RegionSuccessor &region,
ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
};
} // end namespace detail
} // end namespace mlir
/// Given the results of getConstant{Lower,Upper}Bound()
/// or getConstantStep() on a LoopLikeInterface return the lower/upper bound for
/// that result if possible.
static APInt getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
Type boundType,
detail::IntRangeAnalysisImpl &analysis,
bool getUpper) {
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
if (loopBound.hasValue()) {
if (loopBound->is<Attribute>()) {
if (auto bound =
loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
return bound.getValue();
} else if (loopBound->is<Value>()) {
LatticeElement<IntRangeLattice> *lattice =
analysis.lookupLatticeElement(loopBound->get<Value>());
if (lattice != nullptr)
return getUpper ? lattice->getValue().value.smax()
: lattice->getValue().value.smin();
}
}
return getUpper ? APInt::getSignedMaxValue(width)
: APInt::getSignedMinValue(width);
}
ChangeResult detail::IntRangeAnalysisImpl::visitOperation(
Operation *op, ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
ChangeResult result = ChangeResult::NoChange;
// Ignore non-integer outputs - return early if the op has no scalar
// integer results
bool hasIntegerResult = false;
for (Value v : op->getResults()) {
if (v.getType().isIntOrIndex())
hasIntegerResult = true;
else
result |= markAllPessimisticFixpoint(v);
}
if (!hasIntegerResult)
return result;
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
LLVM_DEBUG(inferrable->print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
SmallVector<ConstantIntRanges> argRanges(
llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
return val->getValue().value;
}));
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
Optional<IntRangeLattice> oldRange;
if (!lattice.isUninitialized())
oldRange = lattice.getValue();
result |= lattice.join(IntRangeLattice(attrs));
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
// the dataflow analysis in MLIR doesn't attempt to work out trip counts
// and often can't).
bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
return op->hasTrait<OpTrait::IsTerminator>();
});
if (isYieldedResult && oldRange.hasValue() &&
!(lattice.getValue() == *oldRange)) {
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
result |= lattice.markPessimisticFixpoint();
}
};
inferrable.inferResultRanges(argRanges, joinCallback);
for (Value opResult : op->getResults()) {
LatticeElement<IntRangeLattice> &lattice = getLatticeElement(opResult);
// setResultRange() not called, make pessimistic.
if (lattice.isUninitialized())
result |= lattice.markPessimisticFixpoint();
}
} else if (op->getNumRegions() == 0) {
// No regions + no result inference method -> unbounded results (ex. memory
// ops)
result |= markAllPessimisticFixpoint(op->getResults());
}
return result;
}
LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
BranchOpInterface branch,
ArrayRef<LatticeElement<IntRangeLattice> *> operands,
SmallVectorImpl<Block *> &successors) {
auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
Optional<APInt> maybeConstValue =
enumPair.value()->getValue().value.getConstantValue();
if (maybeConstValue) {
return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
*maybeConstValue);
}
return {};
};
SmallVector<Attribute> inferredConsts(
llvm::map_range(llvm::enumerate(operands), toConstantAttr));
if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
successors.push_back(singleSucc);
return success();
}
return failure();
}
void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
ArrayRef<LatticeElement<IntRangeLattice> *> operands,
SmallVectorImpl<RegionSuccessor> &successors) {
auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
Optional<APInt> maybeConstValue =
enumPair.value()->getValue().value.getConstantValue();
if (maybeConstValue) {
return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
*maybeConstValue);
}
return {};
};
SmallVector<Attribute> inferredConsts(
llvm::map_range(llvm::enumerate(operands), toConstantAttr));
branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
}
ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
Operation *op, const RegionSuccessor &region,
ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
LLVM_DEBUG(inferrable->print(llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
SmallVector<ConstantIntRanges> argRanges(
llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
return val->getValue().value;
}));
ChangeResult result = ChangeResult::NoChange;
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
Optional<IntRangeLattice> oldRange;
if (!lattice.isUninitialized())
oldRange = lattice.getValue();
result |= lattice.join(IntRangeLattice(attrs));
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
// the dataflow analysis in MLIR doesn't attempt to work out trip counts
// and often can't).
bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
return op->hasTrait<OpTrait::IsTerminator>();
});
if (isYieldedValue && oldRange.hasValue() &&
!(lattice.getValue() == *oldRange)) {
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
result |= lattice.markPessimisticFixpoint();
}
};
inferrable.inferResultRanges(argRanges, joinCallback);
for (Value regionArg : region.getSuccessor()->getArguments()) {
LatticeElement<IntRangeLattice> &lattice = getLatticeElement(regionArg);
// setResultRange() not called, make pessimistic.
if (lattice.isUninitialized())
result |= lattice.markPessimisticFixpoint();
}
return result;
}
// Infer bounds for loop arguments that have static bounds
if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
Optional<Value> iv = loop.getSingleInductionVar();
if (!iv.hasValue()) {
return ForwardDataFlowAnalysis<
IntRangeLattice>::visitNonControlFlowArguments(op, region, operands);
}
Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
Optional<OpFoldResult> step = loop.getSingleStep();
APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this,
/*getUpper=*/false);
APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this,
/*getUpper=*/true);
// Assume positivity for uniscoverable steps by way of getUpper = true.
APInt stepVal =
getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true);
if (stepVal.isNegative()) {
std::swap(min, max);
} else {
// Correct the upper bound by subtracting 1 so that it becomes a <= bound,
// because loops do not generally include their upper bound.
max -= 1;
}
LatticeElement<IntRangeLattice> &ivEntry = getLatticeElement(*iv);
return ivEntry.join(ConstantIntRanges::fromSigned(min, max));
}
return ForwardDataFlowAnalysis<IntRangeLattice>::visitNonControlFlowArguments(
op, region, operands);
}
IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) {
impl = std::make_unique<mlir::detail::IntRangeAnalysisImpl>(
topLevelOperation->getContext());
impl->run(topLevelOperation);
}
IntRangeAnalysis::~IntRangeAnalysis() = default;
IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default;
Optional<ConstantIntRanges> IntRangeAnalysis::getResult(Value v) {
LatticeElement<IntRangeLattice> *result = impl->lookupLatticeElement(v);
if (result == nullptr || result->isUninitialized())
return llvm::None;
return result->getValue().value;
}

View File

@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
CopyOpInterface.cpp
DataLayoutInterfaces.cpp
DerivedAttributeOpInterface.cpp
InferIntRangeInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
SideEffectInterfaces.cpp
@ -35,6 +36,7 @@ add_mlir_interface_library(ControlFlowInterfaces)
add_mlir_interface_library(CopyOpInterface)
add_mlir_interface_library(DataLayoutInterfaces)
add_mlir_interface_library(DerivedAttributeOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_interface_library(SideEffectInterfaces)
add_mlir_interface_library(TilingInterface)

View File

@ -0,0 +1,99 @@
//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
using namespace mlir;
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
return umin().getBitWidth() == other.umin().getBitWidth() &&
umin() == other.umin() && umax() == other.umax() &&
smin() == other.smin() && smax() == other.smax();
}
const APInt &ConstantIntRanges::umin() const { return uminVal; }
const APInt &ConstantIntRanges::umax() const { return umaxVal; }
const APInt &ConstantIntRanges::smin() const { return sminVal; }
const APInt &ConstantIntRanges::smax() const { return smaxVal; }
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
if (type.isIndex())
return IndexType::kInternalStorageBitWidth;
if (auto integerType = type.dyn_cast<IntegerType>())
return integerType.getWidth();
// Non-integer types have their bounds stored in width 0 `APInt`s.
return 0;
}
ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) {
return {min, max, min, max};
}
ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
const APInt &smax) {
unsigned int width = smin.getBitWidth();
APInt umin, umax;
if (smin.isNonNegative() == smax.isNonNegative()) {
umin = smin.ult(smax) ? smin : smax;
umax = smin.ugt(smax) ? smin : smax;
} else {
umin = APInt::getMinValue(width);
umax = APInt::getMaxValue(width);
}
return {umin, umax, smin, smax};
}
ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
const APInt &umax) {
unsigned int width = umin.getBitWidth();
APInt smin, smax;
if (umin.isNonNegative() == umax.isNonNegative()) {
smin = umin.slt(umax) ? umin : umax;
smax = umin.sgt(umax) ? umin : umax;
} else {
smin = APInt::getSignedMinValue(width);
smax = APInt::getSignedMaxValue(width);
}
return {umin, umax, smin, smax};
}
ConstantIntRanges
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
// "Not an integer" poisons everything and also cannot be fed to comparison
// operators.
if (umin().getBitWidth() == 0)
return *this;
if (other.umin().getBitWidth() == 0)
return other;
const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
return {uminUnion, umaxUnion, sminUnion, smaxUnion};
}
Optional<APInt> ConstantIntRanges::getConstantValue() const {
// Note: we need to exclude the trivially-equal width 0 values here.
if (umin() == umax() && umin().getBitWidth() != 0)
return umin();
if (smin() == smax() && smin().getBitWidth() != 0)
return smin();
return None;
}
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
return os << "unsigned : [" << range.umin() << ", " << range.umax()
<< "] signed : [" << range.smin() << ", " << range.smax() << "]";
}

View File

@ -0,0 +1,102 @@
// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
// CHECK-LABEL: func @constant
// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index}
// CHECK: return %[[cst]]
func.func @constant() -> index {
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
smin = 3 : index, smax = 3 : index}
func.return %0 : index
}
// CHECK-LABEL: func @increment
// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index}
// CHECK: return %[[cst]]
func.func @increment() -> index {
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
%1 = test.increment %0
func.return %1 : index
}
// CHECK-LABEL: func @maybe_increment
// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
func.func @maybe_increment(%arg0 : i1) -> index {
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
smin = 3 : index, smax = 3 : index}
%1 = scf.if %arg0 -> index {
scf.yield %0 : index
} else {
%2 = test.increment %0
scf.yield %2 : index
}
%3 = test.reflect_bounds %1
func.return %3 : index
}
// CHECK-LABEL: func @maybe_increment_br
// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
func.func @maybe_increment_br(%arg0 : i1) -> index {
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
smin = 3 : index, smax = 3 : index}
cf.cond_br %arg0, ^bb0, ^bb1
^bb0:
%1 = test.increment %0
cf.br ^bb2(%1 : index)
^bb1:
cf.br ^bb2(%0 : index)
^bb2(%2 : index):
%3 = test.reflect_bounds %2
func.return %3 : index
}
// CHECK-LABEL: func @for_bounds
// CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index}
func.func @for_bounds() -> index {
%c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index}
%c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
smin = 1 : index, smax = 1 : index}
%c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
smin = 2 : index, smax = 2 : index}
%0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
scf.yield %arg0 : index
}
%1 = test.reflect_bounds %0
func.return %1 : index
}
// CHECK-LABEL: func @no_analysis_of_loop_variants
// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
func.func @no_analysis_of_loop_variants() -> index {
%c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index}
%c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
smin = 1 : index, smax = 1 : index}
%c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
smin = 2 : index, smax = 2 : index}
%0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
%1 = test.increment %arg2
scf.yield %1 : index
}
%2 = test.reflect_bounds %0
func.return %2 : index
}
// CHECK-LABEL: func @region_args
// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
func.func @region_args() {
test.with_bounds_region { umin = 3 : index, umax = 4 : index,
smin = 3 : index, smax = 4 : index } %arg0 {
%0 = test.reflect_bounds %arg0
}
func.return
}
// CHECK-LABEL: func @func_args_unbound
// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
func.func @func_args_unbound(%arg0 : index) -> index {
%0 = test.reflect_bounds %arg0
func.return %0 : index
}

View File

@ -62,6 +62,7 @@ add_mlir_library(MLIRTestDialect
MLIRFunc
MLIRFuncTransforms
MLIRIR
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRLinalg
MLIRLinalgTransforms

View File

@ -14,15 +14,21 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
@ -1396,6 +1402,66 @@ LogicalResult TestVerifiersOp::verifyRegions() {
return success();
}
//===----------------------------------------------------------------------===//
// Test InferIntRangeInterface
//===----------------------------------------------------------------------===//
void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
}
ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
OperationState &result) {
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
// Parse the input argument
OpAsmParser::Argument argInfo;
argInfo.type = parser.getBuilder().getIndexType();
parser.parseArgument(argInfo);
// Parse the body region, and reuse the operand info as the argument info.
Region *body = result.addRegion();
return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
}
void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs());
p << ' ';
p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
/*omitType=*/true);
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
void TestWithBoundsRegionOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
Value arg = getRegion().getArgument(0);
setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
}
void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
const ConstantIntRanges &range = argRanges[0];
APInt one(range.umin().getBitWidth(), 1);
setResultRanges(getResult(),
{range.umin().uadd_sat(one), range.umax().uadd_sat(one),
range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
}
void TestReflectBoundsOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
const ConstantIntRanges &range = argRanges[0];
MLIRContext *ctx = getContext();
Builder b(ctx);
setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
setResultRanges(getResult(), range);
}
#include "TestOpEnums.cpp.inc"
#include "TestOpInterfaces.cpp.inc"
#include "TestOpStructs.cpp.inc"

View File

@ -33,6 +33,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -23,6 +23,7 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -789,7 +790,7 @@ def StringAttrPrettyNameOp
def CustomResultsNameOp
: TEST_Op<"custom_result_name",
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let arguments = (ins
let arguments = (ins
Variadic<AnyInteger>:$optional,
StrArrayAttr:$names
);
@ -2885,4 +2886,51 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
}];
}
//===----------------------------------------------------------------------===//
// Test InferIntRangeInterface
//===----------------------------------------------------------------------===//
def TestWithBoundsOp : TEST_Op<"with_bounds",
[DeclareOpInterfaceMethods<InferIntRangeInterface>,
NoSideEffect]> {
let arguments = (ins IndexAttr:$umin,
IndexAttr:$umax,
IndexAttr:$smin,
IndexAttr:$smax);
let results = (outs Index:$fakeVal);
let assemblyFormat = "attr-dict";
}
def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
[DeclareOpInterfaceMethods<InferIntRangeInterface>,
SingleBlock, NoTerminator]> {
let arguments = (ins IndexAttr:$umin,
IndexAttr:$umax,
IndexAttr:$smin,
IndexAttr:$smax);
// The region has one argument of index type
let regions = (region SizedRegion<1>:$region);
let hasCustomAssemblyFormat = 1;
}
def TestIncrementOp : TEST_Op<"increment",
[DeclareOpInterfaceMethods<InferIntRangeInterface>,
NoSideEffect]> {
let arguments = (ins Index:$value);
let results = (outs Index:$result);
let assemblyFormat = "attr-dict $value";
}
def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
let arguments = (ins Index:$value,
OptionalAttr<IndexAttr>:$umin,
OptionalAttr<IndexAttr>:$umax,
OptionalAttr<IndexAttr>:$smin,
OptionalAttr<IndexAttr>:$smax);
let results = (outs Index:$result);
let assemblyFormat = "attr-dict $value";
}
#endif // TEST_OPS

View File

@ -3,6 +3,7 @@ add_mlir_library(MLIRTestTransforms
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
TestIntRangeInference.cpp
EXCLUDE_FROM_LIBMLIR
@ -10,6 +11,8 @@ add_mlir_library(MLIRTestTransforms
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRInferIntRangeInterface
MLIRTestDialect
MLIRTransforms
)

View File

@ -0,0 +1,115 @@
//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// TODO: This pass is needed to test integer range inference until that
// functionality has been integrated into SCCP.
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/IntRangeAnalysis.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/FoldUtils.h"
using namespace mlir;
/// Patterned after SCCP
static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
OpBuilder &b, OperationFolder &folder,
Value value) {
Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
if (!maybeInferredRange)
return failure();
const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
if (!maybeConstValue.hasValue())
return failure();
Operation *maybeDefiningOp = value.getDefiningOp();
Dialect *valueDialect =
maybeDefiningOp ? maybeDefiningOp->getDialect()
: value.getParentRegion()->getParentOp()->getDialect();
Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
value.getType(), value.getLoc());
if (!constant)
return failure();
value.replaceAllUsesWith(constant);
return success();
}
static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
MutableArrayRef<Region> initialRegions) {
SmallVector<Block *> worklist;
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
for (Region &region : regions)
for (Block &block : llvm::reverse(region))
worklist.push_back(&block);
};
OpBuilder builder(context);
OperationFolder folder(context);
addToWorklist(initialRegions);
while (!worklist.empty()) {
Block *block = worklist.pop_back_val();
for (Operation &op : llvm::make_early_inc_range(*block)) {
builder.setInsertionPoint(&op);
// Replace any result with constants.
bool replacedAll = op.getNumResults() != 0;
for (Value res : op.getResults())
replacedAll &=
succeeded(replaceWithConstant(analysis, builder, folder, res));
// If all of the results of the operation were replaced, try to erase
// the operation completely.
if (replacedAll && wouldOpBeTriviallyDead(&op)) {
assert(op.use_empty() && "expected all uses to be replaced");
op.erase();
continue;
}
// Add any the regions of this operation to the worklist.
addToWorklist(op.getRegions());
}
// Replace any block arguments with constants.
builder.setInsertionPointToStart(block);
for (BlockArgument arg : block->getArguments())
(void)replaceWithConstant(analysis, builder, folder, arg);
}
}
namespace {
struct TestIntRangeInference
: PassWrapper<TestIntRangeInference, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
StringRef getArgument() const final { return "test-int-range-inference"; }
StringRef getDescription() const final {
return "Test integer range inference analysis";
}
void runOnOperation() override {
Operation *op = getOperation();
IntRangeAnalysis analysis(op);
rewrite(analysis, op->getContext(), op->getRegions());
}
};
} // end anonymous namespace
namespace mlir {
namespace test {
void registerTestIntRangeInference() {
PassRegistration<TestIntRangeInference>();
}
} // end namespace test
} // end namespace mlir

View File

@ -79,6 +79,7 @@ void registerTestDynamicPipelinePass();
void registerTestExpandMathPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
void registerTestIntRangeInference();
void registerTestIRVisitorsPass();
void registerTestGenericIRVisitorsPass();
void registerTestGenericIRVisitorsInterruptPass();
@ -175,6 +176,7 @@ void registerTestPasses() {
mlir::test::registerTestExpandMathPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIntRangeInference();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();

View File

@ -1,6 +1,7 @@
add_mlir_unittest(MLIRInterfacesTests
ControlFlowInterfacesTest.cpp
DataLayoutInterfacesTest.cpp
InferIntRangeInterfaceTest.cpp
InferTypeOpInterfaceTest.cpp
)
@ -10,6 +11,7 @@ target_link_libraries(MLIRInterfacesTests
MLIRDataLayoutInterfaces
MLIRDLTI
MLIRFunc
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRParser
)

View File

@ -0,0 +1,99 @@
//===- InferIntRangeInterfaceTest.cpp - Unit Tests for InferIntRange... --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/APInt.h"
#include <limits>
#include <gtest/gtest.h>
using namespace mlir;
TEST(IntRangeAttrs, BasicConstructors) {
APInt zero = APInt::getZero(64);
APInt two(64, 2);
APInt three(64, 3);
ConstantIntRanges boundedAbove(zero, two, zero, three);
EXPECT_EQ(boundedAbove.umin(), zero);
EXPECT_EQ(boundedAbove.umax(), two);
EXPECT_EQ(boundedAbove.smin(), zero);
EXPECT_EQ(boundedAbove.smax(), three);
}
TEST(IntRangeAttrs, FromUnsigned) {
APInt zero = APInt::getZero(64);
APInt maxInt = APInt::getSignedMaxValue(64);
APInt minInt = APInt::getSignedMinValue(64);
APInt minIntPlusOne = minInt + 1;
ConstantIntRanges canPortToSigned =
ConstantIntRanges::fromUnsigned(zero, maxInt);
EXPECT_EQ(canPortToSigned.smin(), zero);
EXPECT_EQ(canPortToSigned.smax(), maxInt);
ConstantIntRanges cantPortToSigned =
ConstantIntRanges::fromUnsigned(zero, minInt);
EXPECT_EQ(cantPortToSigned.smin(), minInt);
EXPECT_EQ(cantPortToSigned.smax(), maxInt);
ConstantIntRanges signedNegative =
ConstantIntRanges::fromUnsigned(minInt, minIntPlusOne);
EXPECT_EQ(signedNegative.smin(), minInt);
EXPECT_EQ(signedNegative.smax(), minIntPlusOne);
}
TEST(IntRangeAttrs, FromSigned) {
APInt zero = APInt::getZero(64);
APInt one = zero + 1;
APInt negOne = zero - 1;
APInt intMax = APInt::getSignedMaxValue(64);
APInt intMin = APInt::getSignedMinValue(64);
APInt uintMax = APInt::getMaxValue(64);
ConstantIntRanges noUnsignedBound =
ConstantIntRanges::fromSigned(negOne, one);
EXPECT_EQ(noUnsignedBound.umin(), zero);
EXPECT_EQ(noUnsignedBound.umax(), uintMax);
ConstantIntRanges positive = ConstantIntRanges::fromSigned(one, intMax);
EXPECT_EQ(positive.umin(), one);
EXPECT_EQ(positive.umax(), intMax);
ConstantIntRanges negative = ConstantIntRanges::fromSigned(intMin, negOne);
EXPECT_EQ(negative.umin(), intMin);
EXPECT_EQ(negative.umax(), negOne);
ConstantIntRanges preserved = ConstantIntRanges::fromSigned(zero, one);
EXPECT_EQ(preserved.umin(), zero);
EXPECT_EQ(preserved.umax(), one);
}
TEST(IntRangeAttrs, Join) {
APInt zero = APInt::getZero(64);
APInt one = zero + 1;
APInt two = zero + 2;
APInt intMin = APInt::getSignedMinValue(64);
APInt intMax = APInt::getSignedMaxValue(64);
APInt uintMax = APInt::getMaxValue(64);
ConstantIntRanges maximal(zero, uintMax, intMin, intMax);
ConstantIntRanges zeroOne(zero, one, zero, one);
EXPECT_EQ(zeroOne.rangeUnion(maximal), maximal);
EXPECT_EQ(maximal.rangeUnion(zeroOne), maximal);
EXPECT_EQ(zeroOne.rangeUnion(zeroOne), zeroOne);
ConstantIntRanges oneTwo(one, two, one, two);
ConstantIntRanges zeroTwo(zero, two, zero, two);
EXPECT_EQ(zeroOne.rangeUnion(oneTwo), zeroTwo);
ConstantIntRanges zeroOneUnsignedOnly(zero, one, intMin, intMax);
ConstantIntRanges zeroOneSignedOnly(zero, uintMax, zero, one);
EXPECT_EQ(zeroOneUnsignedOnly.rangeUnion(zeroOneSignedOnly), maximal);
}