forked from OSchip/llvm-project
Add a constant folding hook to ExtractElementOp to fold extracting the element of a constant. This also adds a 'getValue' function to DenseElementsAttr and SparseElementsAttr to get the element at a constant index.
PiperOrigin-RevId: 230098938
This commit is contained in:
parent
119af6712e
commit
512d87cefc
|
@ -348,6 +348,10 @@ public:
|
||||||
static DenseElementsAttr get(VectorOrTensorType type,
|
static DenseElementsAttr get(VectorOrTensorType type,
|
||||||
ArrayRef<Attribute> values);
|
ArrayRef<Attribute> values);
|
||||||
|
|
||||||
|
/// Return the value at the given index. If index does not refer to a valid
|
||||||
|
/// element, then a null attribute is returned.
|
||||||
|
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||||
|
|
||||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||||
|
|
||||||
ArrayRef<char> getRawData() const;
|
ArrayRef<char> getRawData() const;
|
||||||
|
@ -427,10 +431,11 @@ public:
|
||||||
/// This class uses COO (coordinate list) encoding to represent the sparse
|
/// This class uses COO (coordinate list) encoding to represent the sparse
|
||||||
/// elements in an element attribute. Specifically, the sparse vector/tensor
|
/// elements in an element attribute. Specifically, the sparse vector/tensor
|
||||||
/// stores the indices and values as two separate dense elements attributes. The
|
/// stores the indices and values as two separate dense elements attributes. The
|
||||||
/// dense elements attribute indices is a 2-D tensor with shape [N, ndims],
|
/// dense elements attribute indices is a 2-D tensor of 64-bit integer elements
|
||||||
/// which specifies the indices of the elements in the sparse tensor that
|
/// with shape [N, ndims], which specifies the indices of the elements in the
|
||||||
/// contains nonzero values. The dense elements attribute values is a 1-D tensor
|
/// sparse tensor that contains nonzero values. The dense elements attribute
|
||||||
/// with shape [N], and it supplies the corresponding values for the indices.
|
/// values is a 1-D tensor with shape [N], and it supplies the corresponding
|
||||||
|
/// values for the indices.
|
||||||
///
|
///
|
||||||
/// For example,
|
/// For example,
|
||||||
/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
|
/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
|
||||||
|
@ -450,8 +455,8 @@ public:
|
||||||
|
|
||||||
DenseElementsAttr getValues() const;
|
DenseElementsAttr getValues() const;
|
||||||
|
|
||||||
/// Return the value at the given index.
|
/// Return the value of the element at the given index.
|
||||||
Attribute getValue(ArrayRef<unsigned> index) const;
|
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool kindof(Kind kind) { return kind == Kind::SparseElements; }
|
static bool kindof(Kind kind) { return kind == Kind::SparseElements; }
|
||||||
|
|
|
@ -523,6 +523,8 @@ public:
|
||||||
bool verify() const;
|
bool verify() const;
|
||||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||||
void print(OpAsmPrinter *p) const;
|
void print(OpAsmPrinter *p) const;
|
||||||
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
|
MLIRContext *context) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class OperationInst;
|
friend class OperationInst;
|
||||||
|
|
|
@ -149,6 +149,53 @@ Attribute SplatElementsAttr::getValue() const {
|
||||||
|
|
||||||
/// DenseElementsAttr
|
/// DenseElementsAttr
|
||||||
|
|
||||||
|
/// Return the value at the given index. If index does not refer to a valid
|
||||||
|
/// element, then a null attribute is returned.
|
||||||
|
Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||||
|
auto type = getType();
|
||||||
|
|
||||||
|
// Verify that the rank of the indices matches the held type.
|
||||||
|
auto rank = type.getRank();
|
||||||
|
if (rank != index.size())
|
||||||
|
return Attribute();
|
||||||
|
|
||||||
|
// Verify that all of the indices are within the shape dimensions.
|
||||||
|
auto shape = type.getShape();
|
||||||
|
for (unsigned i = 0; i != rank; ++i)
|
||||||
|
if (shape[i] <= index[i])
|
||||||
|
return Attribute();
|
||||||
|
|
||||||
|
// Reduce the provided multidimensional index into a 1D index.
|
||||||
|
uint64_t valueIndex = 0;
|
||||||
|
uint64_t dimMultiplier = 1;
|
||||||
|
for (int i = rank - 1; i >= 0; --i) {
|
||||||
|
valueIndex += index[i] * dimMultiplier;
|
||||||
|
dimMultiplier *= shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the element stored at the 1D index.
|
||||||
|
|
||||||
|
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||||
|
// with double semantics.
|
||||||
|
auto elementType = getType().getElementType();
|
||||||
|
size_t bitWidth =
|
||||||
|
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
||||||
|
APInt rawValueData =
|
||||||
|
readBits(getRawData().data(), valueIndex * bitWidth, bitWidth);
|
||||||
|
|
||||||
|
// Convert the raw value data to an attribute value.
|
||||||
|
switch (getKind()) {
|
||||||
|
case Attribute::Kind::DenseIntElements:
|
||||||
|
return IntegerAttr::get(elementType, rawValueData);
|
||||||
|
case Attribute::Kind::DenseFPElements:
|
||||||
|
return FloatAttr::get(
|
||||||
|
elementType, APFloat(elementType.cast<FloatType>().getFloatSemantics(),
|
||||||
|
rawValueData));
|
||||||
|
default:
|
||||||
|
llvm_unreachable("unexpected element type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||||
auto elementType = getType().getElementType();
|
auto elementType = getType().getElementType();
|
||||||
switch (getKind()) {
|
switch (getKind()) {
|
||||||
|
@ -188,8 +235,8 @@ void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
|
||||||
auto elementNum = getType().getNumElements();
|
auto elementNum = getType().getNumElements();
|
||||||
values.reserve(elementNum);
|
values.reserve(elementNum);
|
||||||
|
|
||||||
// FIXME: using 64 bits for BF16 because it is currently stored with double
|
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||||
// semantics.
|
// with double semantics.
|
||||||
size_t bitWidth =
|
size_t bitWidth =
|
||||||
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
||||||
const auto *rawData = getRawData().data();
|
const auto *rawData = getRawData().data();
|
||||||
|
@ -281,3 +328,40 @@ DenseIntElementsAttr SparseElementsAttr::getIndices() const {
|
||||||
DenseElementsAttr SparseElementsAttr::getValues() const {
|
DenseElementsAttr SparseElementsAttr::getValues() const {
|
||||||
return static_cast<ImplType *>(attr)->values;
|
return static_cast<ImplType *>(attr)->values;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the value of the element at the given index.
|
||||||
|
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||||
|
auto type = getType();
|
||||||
|
|
||||||
|
// Verify that the rank of the indices matches the held type.
|
||||||
|
auto rank = type.getRank();
|
||||||
|
if (rank != index.size())
|
||||||
|
return Attribute();
|
||||||
|
|
||||||
|
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
|
||||||
|
// as a 1-D index array.
|
||||||
|
auto sparseIndices = getIndices();
|
||||||
|
const uint64_t *sparseIndexValues =
|
||||||
|
reinterpret_cast<const uint64_t *>(sparseIndices.getRawData().data());
|
||||||
|
|
||||||
|
// Build a mapping between known indices and the offset of the stored element.
|
||||||
|
llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
|
||||||
|
size_t numSparseIndices = sparseIndices.getType().getDimSize(0);
|
||||||
|
for (size_t i = 0, e = numSparseIndices; i != e; ++i)
|
||||||
|
mappedIndices.try_emplace(
|
||||||
|
{sparseIndexValues + (i * rank), static_cast<size_t>(rank)}, i);
|
||||||
|
|
||||||
|
// Look for the provided index key within the mapped indices. If the provided
|
||||||
|
// index is not found, then return a zero attribute.
|
||||||
|
auto it = mappedIndices.find(index);
|
||||||
|
if (it == mappedIndices.end()) {
|
||||||
|
auto eltType = type.getElementType();
|
||||||
|
if (eltType.isa<FloatType>())
|
||||||
|
return FloatAttr::get(eltType, 0);
|
||||||
|
assert(eltType.isa<IntegerType>() && "unexpected element type");
|
||||||
|
return IntegerAttr::get(eltType, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, return the held sparse value element.
|
||||||
|
return getValues().getValue(it->second);
|
||||||
|
}
|
||||||
|
|
|
@ -1074,8 +1074,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
||||||
assert(values.size() == type.getNumElements() &&
|
assert(values.size() == type.getNumElements() &&
|
||||||
"expected 'values' to contain the same number of elements as 'type'");
|
"expected 'values' to contain the same number of elements as 'type'");
|
||||||
|
|
||||||
// FIXME: using 64 bits for BF16 because it is currently stored with double
|
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||||
// semantics.
|
// with double semantics.
|
||||||
auto eltType = type.getElementType();
|
auto eltType = type.getElementType();
|
||||||
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||||
|
|
||||||
|
@ -1136,6 +1136,9 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
|
||||||
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
|
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
|
||||||
DenseIntElementsAttr indices,
|
DenseIntElementsAttr indices,
|
||||||
DenseElementsAttr values) {
|
DenseElementsAttr values) {
|
||||||
|
assert(indices.getType().getElementType().isInteger(64) &&
|
||||||
|
"expected sparse indices to be 64-bit integer values");
|
||||||
|
|
||||||
auto &impl = type.getContext()->getImpl();
|
auto &impl = type.getContext()->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this.
|
// Look to see if we already have this.
|
||||||
|
|
|
@ -1076,7 +1076,7 @@ Attribute Parser::parseAttribute(Type type) {
|
||||||
switch (getToken().getKind()) {
|
switch (getToken().getKind()) {
|
||||||
case Token::l_square: {
|
case Token::l_square: {
|
||||||
/// Parse indices
|
/// Parse indices
|
||||||
auto indicesEltType = builder.getIntegerType(32);
|
auto indicesEltType = builder.getIntegerType(64);
|
||||||
auto indices =
|
auto indices =
|
||||||
parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
|
parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
|
||||||
if (!indices)
|
if (!indices)
|
||||||
|
|
|
@ -1136,6 +1136,41 @@ bool ExtractElementOp::verify() const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Attribute ExtractElementOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
|
MLIRContext *context) const {
|
||||||
|
assert(operands.size() > 1 && "extract_element takes atleast one operands");
|
||||||
|
|
||||||
|
// The aggregate operand must be a known constant.
|
||||||
|
Attribute aggregate = operands.front();
|
||||||
|
if (!aggregate)
|
||||||
|
return Attribute();
|
||||||
|
|
||||||
|
// If this is a splat elements attribute, simply return the value. All of the
|
||||||
|
// elements of a splat attribute are the same.
|
||||||
|
if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
|
||||||
|
return splatAggregate.getValue();
|
||||||
|
|
||||||
|
// Otherwise, collect the constant indices into the aggregate.
|
||||||
|
SmallVector<uint64_t, 8> indices;
|
||||||
|
for (Attribute indice : llvm::drop_begin(operands, 1)) {
|
||||||
|
if (!indice || !indice.isa<IntegerAttr>())
|
||||||
|
return Attribute();
|
||||||
|
indices.push_back(indice.cast<IntegerAttr>().getInt());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the element value of the aggregate attribute with the given constant
|
||||||
|
// indices.
|
||||||
|
switch (aggregate.getKind()) {
|
||||||
|
case Attribute::Kind::DenseFPElements:
|
||||||
|
case Attribute::Kind::DenseIntElements:
|
||||||
|
return aggregate.cast<DenseElementsAttr>().getValue(indices);
|
||||||
|
case Attribute::Kind::SparseElements:
|
||||||
|
return aggregate.cast<SparseElementsAttr>().getValue(indices);
|
||||||
|
default:
|
||||||
|
return Attribute();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LoadOp
|
// LoadOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -320,6 +320,36 @@ func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
|
||||||
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
|
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_extract_element
|
||||||
|
func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
|
||||||
|
%const_0 = constant 0 : index
|
||||||
|
%const_1 = constant 1 : index
|
||||||
|
%const_3 = constant 3 : index
|
||||||
|
|
||||||
|
// Fold an extract into a splat.
|
||||||
|
// CHECK-NEXT: {{.*}} = constant 4.500000e+00 : f32
|
||||||
|
%0 = constant splat<tensor<4xf32>, 4.5> : tensor<4xf32>
|
||||||
|
%ext_1 = extract_element %0[%arg0] : tensor<4xf32>
|
||||||
|
|
||||||
|
// Fold an extract into a sparse with a sparse index.
|
||||||
|
// CHECK-NEXT: {{.*}} = constant -2.000000e+00 : f16
|
||||||
|
%1 = constant sparse<vector<1x1x1xf16>, [[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : vector<1x1x1xf16>
|
||||||
|
%ext_2 = extract_element %1[%const_1, %const_1, %const_1] : vector<1x1x1xf16>
|
||||||
|
|
||||||
|
// Fold an extract into a sparse with a non sparse index.
|
||||||
|
// CHECK-NEXT: {{.*}} = constant 0.000000e+00 : f16
|
||||||
|
%2 = constant sparse<vector<1x1x1xf16>, [[1, 1, 1]], [-2.0]> : vector<1x1x1xf16>
|
||||||
|
%ext_3 = extract_element %2[%const_0, %const_0, %const_0] : vector<1x1x1xf16>
|
||||||
|
|
||||||
|
// Fold an extract into a dense tensor.
|
||||||
|
// CHECK-NEXT: {{.*}} = constant 64 : i32
|
||||||
|
%3 = constant dense<tensor<2x1x4xi32>, [[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
|
||||||
|
%ext_4 = extract_element %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
|
||||||
|
}
|
||||||
|
|
||||||
// --------------------------------------------------------------------------//
|
// --------------------------------------------------------------------------//
|
||||||
// IMPORTANT NOTE: the operations in this test are exactly those produced by
|
// IMPORTANT NOTE: the operations in this test are exactly those produced by
|
||||||
// lowering affine_apply (i) -> (i mod 42) to standard operations. Please only
|
// lowering affine_apply (i) -> (i mod 42) to standard operations. Please only
|
||||||
|
|
Loading…
Reference in New Issue