forked from OSchip/llvm-project
[mlir] Add offset/stride helper functions to OffsetSizeAndStrideOpInterface
* Add hasUnitStride and hasZeroOffset to OffsetSizeAndStrideOpInterface. These functions are useful for various patterns. E.g., some vectorization patterns apply only for tensor ops with zero offsets and/or unit stride. * Add getConstantIntValue and isEqualConstantInt helper functions, which are useful for implementing the two above functions, as well as various patterns. Differential Revision: https://reviews.llvm.org/D103763
This commit is contained in:
parent
4f8bc7caf4
commit
6e7bbdd6e7
|
@ -122,6 +122,15 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
|
|||
bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
|
||||
const APFloat &rhs);
|
||||
|
||||
/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
|
||||
/// IntegerAttr, return the integer.
|
||||
llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
|
||||
|
||||
/// Return true if ofr and value are the same integer.
|
||||
/// Ignore integer bitwidth and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType has no bitwidth.
|
||||
bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
|
||||
|
||||
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
|
||||
/// or the same SSA value.
|
||||
/// Ignore integer bitwitdh and type mismatch that come from the fact there is
|
||||
|
|
|
@ -30,6 +30,8 @@ struct Range {
|
|||
|
||||
class OffsetSizeAndStrideOpInterface;
|
||||
|
||||
bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
|
||||
|
||||
namespace detail {
|
||||
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
|
||||
|
||||
|
|
|
@ -436,6 +436,30 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
|
|||
$_op.getOperation()), other, cmp);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{ Return true if all strides are guaranteed to be 1. }],
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"hasUnitStride",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) {
|
||||
return ::mlir::isEqualConstantInt(ofr, 1);
|
||||
});
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{ Return true if all offsets are guaranteed to be 0. }],
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"hasZeroOffset",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) {
|
||||
return ::mlir::isEqualConstantInt(ofr, 0);
|
||||
});
|
||||
}]
|
||||
>,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
|
|
@ -60,24 +60,35 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
|
|||
dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
|
||||
}
|
||||
|
||||
/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
|
||||
/// IntegerAttr, return the integer.
|
||||
llvm::Optional<int64_t> mlir::getConstantIntValue(OpFoldResult ofr) {
|
||||
Attribute attr = ofr.dyn_cast<Attribute>();
|
||||
// Note: isa+cast-like pattern allows writing the condition below as 1 line.
|
||||
if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
|
||||
attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
|
||||
if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
|
||||
return intAttr.getValue().getSExtValue();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// Return true if ofr and value are the same integer.
|
||||
/// Ignore integer bitwidth and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType has no bitwidth.
|
||||
bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) {
|
||||
auto ofrValue = getConstantIntValue(ofr);
|
||||
return ofrValue && *ofrValue == value;
|
||||
}
|
||||
|
||||
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
|
||||
/// or the same SSA value.
|
||||
/// Ignore integer bitwitdh and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType have no bitwidth.
|
||||
bool mlir::isEqualConstantIntOrValue(OpFoldResult op1, OpFoldResult op2) {
|
||||
auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
|
||||
Attribute attr = ofr.dyn_cast<Attribute>();
|
||||
// Note: isa+cast-like pattern allows writing the condition below as 1 line.
|
||||
if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
|
||||
attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
|
||||
if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
|
||||
return intAttr.getValue().getSExtValue();
|
||||
return llvm::None;
|
||||
};
|
||||
auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
|
||||
/// Ignore integer bitwidth and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType has no bitwidth.
|
||||
bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
|
||||
auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
|
||||
if (cst1 && cst2 && *cst1 == *cst2)
|
||||
return true;
|
||||
auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
|
||||
auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
|
||||
return v1 && v2 && v1 == v2;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue