forked from OSchip/llvm-project
[mlir][NFC] Cleanup: Move helper functions to StaticValueUtils
Reduce code duplication: Move various helper functions, that are duplicated in TensorDialect, MemRefDialect, LinalgDialect, StandardDialect, into a new StaticValueUtils.cpp. Differential Revision: https://reviews.llvm.org/D104687
This commit is contained in:
parent
81f6d7c082
commit
0813700de1
|
@ -269,13 +269,13 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
|
|||
// Return true if low padding is guaranteed to be 0.
|
||||
bool hasZeroLowPad() {
|
||||
return llvm::all_of(getMixedLowPad(), [](OpFoldResult ofr) {
|
||||
return mlir::isEqualConstantInt(ofr, 0);
|
||||
return getConstantIntValue(ofr) == static_cast<int64_t>(0);
|
||||
});
|
||||
}
|
||||
// Return true if high padding is guaranteed to be 0.
|
||||
bool hasZeroHighPad() {
|
||||
return llvm::all_of(getMixedHighPad(), [](OpFoldResult ofr) {
|
||||
return mlir::isEqualConstantInt(ofr, 0);
|
||||
return getConstantIntValue(ofr) == static_cast<int64_t>(0);
|
||||
});
|
||||
}
|
||||
}];
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
|
|
@ -114,21 +114,6 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
|
|||
bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
|
||||
const APFloat &rhs);
|
||||
|
||||
/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
|
||||
/// IntegerAttr, return the integer.
|
||||
llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
|
||||
|
||||
/// Return true if ofr and value are the same integer.
|
||||
/// Ignore integer bitwidth and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType has no bitwidth.
|
||||
bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
|
||||
|
||||
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
|
||||
/// or the same SSA value.
|
||||
/// Ignore integer bitwitdh and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType have no bitwidth.
|
||||
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
|
||||
|
||||
/// Returns the identity value attribute associated with an AtomicRMWKind op.
|
||||
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
||||
OpBuilder &builder, Location loc);
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header file defines utilities for dealing with static values, e.g.,
|
||||
// converting back and forth between Value and OpFoldResult. Such functionality
|
||||
// is used in multiple dialects.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
|
||||
#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
|
||||
/// it is a Value or into `staticVec` if it is an IntegerAttr.
|
||||
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
|
||||
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
|
||||
/// come from an AttrSizedOperandSegments trait.
|
||||
void dispatchIndexOpFoldResult(OpFoldResult ofr,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec,
|
||||
int64_t sentinel);
|
||||
|
||||
/// Helper function to dispatch multiple OpFoldResults into either the
|
||||
/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs).
|
||||
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
|
||||
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
|
||||
/// come from an AttrSizedOperandSegments trait.
|
||||
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec,
|
||||
int64_t sentinel);
|
||||
|
||||
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
|
||||
SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
|
||||
|
||||
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
||||
Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
|
||||
|
||||
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
|
||||
/// or the same SSA value.
|
||||
/// Ignore integer bitwitdh and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType have no bitwidth.
|
||||
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
|
|
@ -13,6 +13,7 @@
|
|||
#ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
|
||||
#define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
|
||||
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
@ -30,8 +31,6 @@ struct Range {
|
|||
|
||||
class OffsetSizeAndStrideOpInterface;
|
||||
|
||||
bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
|
||||
|
||||
namespace detail {
|
||||
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
|
||||
|
||||
|
|
|
@ -444,7 +444,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
|
|||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) {
|
||||
return ::mlir::isEqualConstantInt(ofr, 1);
|
||||
return ::mlir::getConstantIntValue(ofr) == static_cast<int64_t>(1);
|
||||
});
|
||||
}]
|
||||
>,
|
||||
|
@ -456,7 +456,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
|
|||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) {
|
||||
return ::mlir::isEqualConstantInt(ofr, 0);
|
||||
return ::mlir::getConstantIntValue(ofr) == static_cast<int64_t>(0);
|
||||
});
|
||||
}]
|
||||
>,
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -3388,14 +3389,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
|
||||
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
|
||||
return a.cast<IntegerAttr>().getInt();
|
||||
}));
|
||||
}
|
||||
|
||||
/// Conversion pattern that transforms a subview op into:
|
||||
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
|
||||
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
@ -116,24 +117,6 @@ static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
|
|||
}));
|
||||
}
|
||||
|
||||
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
|
||||
/// it is a Value or into `staticVec` if it is an IntegerAttr.
|
||||
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
|
||||
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
|
||||
/// come from an AttrSizedOperandSegments trait.
|
||||
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec,
|
||||
int64_t sentinel) {
|
||||
if (auto v = ofr.dyn_cast<Value>()) {
|
||||
dynamicVec.push_back(v);
|
||||
staticVec.push_back(sentinel);
|
||||
return;
|
||||
}
|
||||
APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
|
||||
staticVec.push_back(apInt.getSExtValue());
|
||||
}
|
||||
|
||||
/// This is a common class used for patterns of the form
|
||||
/// ```
|
||||
/// someop(memrefcast(%src)) -> someop(%src)
|
||||
|
@ -819,14 +802,6 @@ LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
|
|||
// PadTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
|
||||
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
|
||||
return a.cast<IntegerAttr>().getInt();
|
||||
}));
|
||||
}
|
||||
|
||||
static LogicalResult verify(PadTensorOp op) {
|
||||
auto sourceType = op.source().getType().cast<RankedTensorType>();
|
||||
auto resultType = op.result().getType().cast<RankedTensorType>();
|
||||
|
|
|
@ -110,6 +110,7 @@
|
|||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
|
|
@ -814,8 +814,8 @@ struct GenericPadTensorOpVectorizationPattern
|
|||
readInBounds.push_back(false);
|
||||
// Write is out-of-bounds if low padding > 0.
|
||||
writeInBounds.push_back(
|
||||
isEqualConstantIntOrValue(padOp.getMixedLowPad()[i],
|
||||
rewriter.getIndexAttr(0)));
|
||||
getConstantIntValue(padOp.getMixedLowPad()[i]) ==
|
||||
static_cast<int64_t>(0));
|
||||
} else {
|
||||
// Neither source nor result dim of padOp is static. Cannot vectorize
|
||||
// the copy.
|
||||
|
@ -1098,9 +1098,9 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
|
|||
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
|
||||
expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
|
||||
if (!llvm::all_of(
|
||||
llvm::zip(insertOp.getMixedSizes(), expectedSizes),
|
||||
[](auto it) { return isEqualConstantInt(std::get<0>(it),
|
||||
std::get<1>(it)); }))
|
||||
llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
|
||||
return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
|
||||
}))
|
||||
return failure();
|
||||
|
||||
// Generate TransferReadOp: Read entire source tensor and add high padding.
|
||||
|
|
|
@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRMemRef
|
|||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRDialect
|
||||
MLIRDialectUtils
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
MLIRMemRefUtils
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
@ -32,40 +33,6 @@ Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
|
|||
return builder.create<mlir::ConstantOp>(loc, type, value);
|
||||
}
|
||||
|
||||
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
|
||||
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
|
||||
return a.cast<IntegerAttr>().getInt();
|
||||
}));
|
||||
}
|
||||
|
||||
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
|
||||
/// it is a Value or into `staticVec` if it is an IntegerAttr.
|
||||
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
|
||||
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
|
||||
/// come from an AttrSizedOperandSegments trait.
|
||||
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec,
|
||||
int64_t sentinel) {
|
||||
if (auto v = ofr.dyn_cast<Value>()) {
|
||||
dynamicVec.push_back(v);
|
||||
staticVec.push_back(sentinel);
|
||||
return;
|
||||
}
|
||||
APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
|
||||
staticVec.push_back(apInt.getSExtValue());
|
||||
}
|
||||
|
||||
static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec,
|
||||
int64_t sentinel) {
|
||||
for (auto ofr : ofrs)
|
||||
dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common canonicalization pattern support logic
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -33,38 +33,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
|
||||
/// IntegerAttr, return the integer.
|
||||
llvm::Optional<int64_t> mlir::getConstantIntValue(OpFoldResult ofr) {
|
||||
Attribute attr = ofr.dyn_cast<Attribute>();
|
||||
// Note: isa+cast-like pattern allows writing the condition below as 1 line.
|
||||
if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
|
||||
attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
|
||||
if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
|
||||
return intAttr.getValue().getSExtValue();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// Return true if ofr and value are the same integer.
|
||||
/// Ignore integer bitwidth and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType has no bitwidth.
|
||||
bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) {
|
||||
auto ofrValue = getConstantIntValue(ofr);
|
||||
return ofrValue && *ofrValue == value;
|
||||
}
|
||||
|
||||
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
|
||||
/// or the same SSA value.
|
||||
/// Ignore integer bitwidth and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType has no bitwidth.
|
||||
bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
|
||||
auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
|
||||
if (cst1 && cst2 && *cst1 == *cst2)
|
||||
return true;
|
||||
auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
|
||||
return v1 && v2 && v1 == v2;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StandardOpsDialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRTensor
|
|||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCastInterfaces
|
||||
MLIRDialectUtils
|
||||
MLIRIR
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRSupport
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
@ -516,32 +517,6 @@ static LogicalResult verify(ReshapeOp op) {
|
|||
// ExtractSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
|
||||
/// it is a Value or into `staticVec` if it is an IntegerAttr.
|
||||
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
|
||||
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
|
||||
/// come from an AttrSizedOperandSegments trait.
|
||||
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec,
|
||||
int64_t sentinel) {
|
||||
if (auto v = ofr.dyn_cast<Value>()) {
|
||||
dynamicVec.push_back(v);
|
||||
staticVec.push_back(sentinel);
|
||||
return;
|
||||
}
|
||||
APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
|
||||
staticVec.push_back(apInt.getSExtValue());
|
||||
}
|
||||
|
||||
static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec,
|
||||
int64_t sentinel) {
|
||||
for (auto ofr : ofrs)
|
||||
dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
|
||||
}
|
||||
|
||||
/// An extract_slice op result type can be fully inferred from the source type
|
||||
/// and the static representation of offsets, sizes and strides. Special
|
||||
/// sentinels encode the dynamic case.
|
||||
|
@ -563,14 +538,6 @@ Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
|
|||
sourceRankedTensorType.getElementType());
|
||||
}
|
||||
|
||||
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
|
||||
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
|
||||
return a.cast<IntegerAttr>().getInt();
|
||||
}));
|
||||
}
|
||||
|
||||
Type ExtractSliceOp::inferResultType(
|
||||
RankedTensorType sourceRankedTensorType,
|
||||
ArrayRef<OpFoldResult> leadingStaticOffsets,
|
||||
|
@ -890,17 +857,16 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
|
|||
ShapedType shapedType) {
|
||||
OpBuilder b(op.getContext());
|
||||
for (OpFoldResult ofr : op.getMixedOffsets())
|
||||
if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0)))
|
||||
if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
|
||||
return failure();
|
||||
// Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
|
||||
// is appropriate.
|
||||
auto shape = shapedType.getShape();
|
||||
for (auto it : llvm::zip(op.getMixedSizes(), shape))
|
||||
if (!isEqualConstantIntOrValue(std::get<0>(it),
|
||||
b.getIndexAttr(std::get<1>(it))))
|
||||
if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
|
||||
return failure();
|
||||
for (OpFoldResult ofr : op.getMixedStrides())
|
||||
if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1)))
|
||||
if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
add_mlir_library(MLIRDialectUtils
|
||||
StructuredOpsUtils.cpp
|
||||
StaticValueUtils.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/APSInt.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
|
||||
/// it is a Value or into `staticVec` if it is an IntegerAttr.
|
||||
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
|
||||
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
|
||||
/// come from an AttrSizedOperandSegments trait.
|
||||
void dispatchIndexOpFoldResult(OpFoldResult ofr,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec,
|
||||
int64_t sentinel) {
|
||||
if (auto v = ofr.dyn_cast<Value>()) {
|
||||
dynamicVec.push_back(v);
|
||||
staticVec.push_back(sentinel);
|
||||
return;
|
||||
}
|
||||
APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
|
||||
staticVec.push_back(apInt.getSExtValue());
|
||||
}
|
||||
|
||||
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec,
|
||||
int64_t sentinel) {
|
||||
for (OpFoldResult ofr : ofrs)
|
||||
dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
|
||||
}
|
||||
|
||||
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
|
||||
SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
|
||||
return a.cast<IntegerAttr>().getInt();
|
||||
}));
|
||||
}
|
||||
|
||||
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
||||
Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
||||
// Case 1: Check for Constant integer.
|
||||
if (auto val = ofr.dyn_cast<Value>()) {
|
||||
APSInt intVal;
|
||||
if (matchPattern(val, m_ConstantInt(&intVal)))
|
||||
return intVal.getSExtValue();
|
||||
return llvm::None;
|
||||
}
|
||||
// Case 2: Check for IntegerAttr.
|
||||
Attribute attr = ofr.dyn_cast<Attribute>();
|
||||
if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
|
||||
return intAttr.getValue().getSExtValue();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
|
||||
/// or the same SSA value.
|
||||
/// Ignore integer bitwidth and type mismatch that come from the fact there is
|
||||
/// no IndexAttr and that IndexType has no bitwidth.
|
||||
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
|
||||
auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
|
||||
if (cst1 && cst2 && *cst1 == *cst2)
|
||||
return true;
|
||||
auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
|
||||
return v1 && v1 == v2;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
Loading…
Reference in New Issue