2018-10-26 06:46:10 +08:00
|
|
|
//===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
#include "mlir/IR/Attributes.h"
|
|
|
|
#include "AttributeDetail.h"
|
2018-10-26 13:13:03 +08:00
|
|
|
#include "mlir/IR/AffineMap.h"
|
2019-02-12 14:51:34 +08:00
|
|
|
#include "mlir/IR/Dialect.h"
|
2018-10-26 06:46:10 +08:00
|
|
|
#include "mlir/IR/Function.h"
|
2018-10-26 13:13:03 +08:00
|
|
|
#include "mlir/IR/IntegerSet.h"
|
2018-10-26 06:46:10 +08:00
|
|
|
#include "mlir/IR/Types.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::detail;
|
|
|
|
|
|
|
|
/// Returns the kind tag recorded on the storage object this attribute wraps.
Attribute::Kind Attribute::getKind() const {
  return this->attr->kind;
}
|
|
|
|
|
|
|
|
bool Attribute::isOrContainsFunction() const {
|
|
|
|
return attr->isOrContainsFunctionCache;
|
|
|
|
}
|
|
|
|
|
2018-11-15 06:26:47 +08:00
|
|
|
// Given an attribute that could refer to a function attribute in the remapping
|
|
|
|
// table, walk it and rewrite it to use the mapped function. If it doesn't
|
|
|
|
// refer to anything in the table, then it is returned unmodified.
|
|
|
|
Attribute Attribute::remapFunctionAttrs(
|
|
|
|
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable,
|
|
|
|
MLIRContext *context) const {
|
|
|
|
// Most attributes are trivially unrelated to function attributes, skip them
|
|
|
|
// rapidly.
|
|
|
|
if (!isOrContainsFunction())
|
|
|
|
return *this;
|
|
|
|
|
|
|
|
// If we have a function attribute, remap it.
|
|
|
|
if (auto fnAttr = this->dyn_cast<FunctionAttr>()) {
|
|
|
|
auto it = remappingTable.find(fnAttr);
|
|
|
|
return it != remappingTable.end() ? it->second : *this;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Otherwise, we must have an array attribute, remap the elements.
|
|
|
|
auto arrayAttr = this->cast<ArrayAttr>();
|
|
|
|
SmallVector<Attribute, 8> remappedElts;
|
|
|
|
bool anyChange = false;
|
|
|
|
for (auto elt : arrayAttr.getValue()) {
|
|
|
|
auto newElt = elt.remapFunctionAttrs(remappingTable, context);
|
|
|
|
remappedElts.push_back(newElt);
|
|
|
|
anyChange |= (elt != newElt);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!anyChange)
|
|
|
|
return *this;
|
|
|
|
|
|
|
|
return ArrayAttr::get(remappedElts, context);
|
|
|
|
}
|
|
|
|
|
2019-02-03 22:15:43 +08:00
|
|
|
/// NumericAttr
|
|
|
|
|
|
|
|
/// Returns the type of the concrete numeric attribute (bool, integer, float or
/// elements) that this NumericAttr refers to.
Type NumericAttr::getType() const {
  if (this->isa<BoolAttr>())
    return this->cast<BoolAttr>().getType();
  if (this->isa<IntegerAttr>())
    return this->cast<IntegerAttr>().getType();
  if (this->isa<FloatAttr>())
    return this->cast<FloatAttr>().getType();
  if (this->isa<ElementsAttr>())
    return this->cast<ElementsAttr>().getType();

  llvm_unreachable("unhandled NumericAttr subclass");
}
|
|
|
|
|
|
|
|
/// Returns true if `kind` belongs to any of the attribute classes that make up
/// NumericAttr: BoolAttr, IntegerAttr, FloatAttr or ElementsAttr.
bool NumericAttr::kindof(Kind kind) {
  if (BoolAttr::kindof(kind) || IntegerAttr::kindof(kind))
    return true;
  return FloatAttr::kindof(kind) || ElementsAttr::kindof(kind);
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// BoolAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the boolean held in this attribute's storage.
bool BoolAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->value;
}
|
|
|
|
|
2019-02-03 22:15:43 +08:00
|
|
|
/// Returns the type recorded in this boolean attribute's storage.
Type BoolAttr::getType() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->type;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// IntegerAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2018-11-12 22:33:22 +08:00
|
|
|
/// Returns the arbitrary-precision integer value, as reconstructed by the
/// storage object's own getValue().
APInt IntegerAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->getValue();
}
|
|
|
|
|
2018-11-12 22:33:22 +08:00
|
|
|
/// Returns the value sign-extended to int64_t. APInt::getSExtValue() asserts
/// that the value fits in 64 bits.
int64_t IntegerAttr::getInt() const {
  const APInt value = getValue();
  return value.getSExtValue();
}
|
|
|
|
|
2018-11-16 09:53:51 +08:00
|
|
|
/// Returns the type recorded in this integer attribute's storage.
Type IntegerAttr::getType() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->type;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// FloatAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the floating-point value, as reconstructed by the storage object's
/// own getValue().
APFloat FloatAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->getValue();
}
|
|
|
|
|
2018-11-16 09:53:51 +08:00
|
|
|
/// Returns the type recorded in this float attribute's storage.
Type FloatAttr::getType() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->type;
}
|
|
|
|
|
2018-12-28 08:51:09 +08:00
|
|
|
double FloatAttr::getValueAsDouble() const {
|
|
|
|
const auto &semantics = getType().cast<FloatType>().getFloatSemantics();
|
|
|
|
auto value = getValue();
|
|
|
|
bool losesInfo = false; // ignored
|
|
|
|
if (&semantics != &APFloat::IEEEdouble()) {
|
|
|
|
value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
|
|
|
|
&losesInfo);
|
|
|
|
}
|
|
|
|
return value.convertToDouble();
|
|
|
|
}
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// StringAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns a reference to the string held in this attribute's storage.
StringRef StringAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->value;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// ArrayAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns a view of the element attributes held in this array's storage.
ArrayRef<Attribute> ArrayAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->value;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// AffineMapAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the affine map held in this attribute's storage.
AffineMap AffineMapAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->value;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// IntegerSetAttr
|
2018-10-26 13:13:03 +08:00
|
|
|
|
|
|
|
/// Returns the integer set held in this attribute's storage.
IntegerSet IntegerSetAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->value;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// TypeAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2018-10-31 05:59:22 +08:00
|
|
|
/// Returns the type value wrapped by this attribute.
Type TypeAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->value;
}
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// FunctionAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the function referenced by this attribute.
Function *FunctionAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->value;
}
|
|
|
|
|
2018-10-31 05:59:22 +08:00
|
|
|
/// Returns the type of the referenced function. The function pointer is
/// dereferenced unchecked, so it is assumed to be non-null.
FunctionType FunctionAttr::getType() const {
  Function *function = getValue();
  return function->getType();
}
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// ElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2018-10-31 05:59:22 +08:00
|
|
|
/// Returns the vector or tensor type recorded in this elements attribute's
/// storage.
VectorOrTensorType ElementsAttr::getType() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->type;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// SplatElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the single element attribute that this splat repeats.
Attribute SplatElementsAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->elt;
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// DenseElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2019-01-20 12:54:09 +08:00
|
|
|
/// Return the value at the given index. If index does not refer to a valid
|
|
|
|
/// element, then a null attribute is returned.
|
|
|
|
Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|
|
|
auto type = getType();
|
|
|
|
|
|
|
|
// Verify that the rank of the indices matches the held type.
|
|
|
|
auto rank = type.getRank();
|
|
|
|
if (rank != index.size())
|
|
|
|
return Attribute();
|
|
|
|
|
|
|
|
// Verify that all of the indices are within the shape dimensions.
|
|
|
|
auto shape = type.getShape();
|
|
|
|
for (unsigned i = 0; i != rank; ++i)
|
|
|
|
if (shape[i] <= index[i])
|
|
|
|
return Attribute();
|
|
|
|
|
|
|
|
// Reduce the provided multidimensional index into a 1D index.
|
|
|
|
uint64_t valueIndex = 0;
|
|
|
|
uint64_t dimMultiplier = 1;
|
2019-01-24 06:39:45 +08:00
|
|
|
for (auto i = rank - 1; i >= 0; --i) {
|
2019-01-20 12:54:09 +08:00
|
|
|
valueIndex += index[i] * dimMultiplier;
|
|
|
|
dimMultiplier *= shape[i];
|
|
|
|
}
|
|
|
|
|
|
|
|
// Return the element stored at the 1D index.
|
|
|
|
|
|
|
|
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
|
|
|
// with double semantics.
|
|
|
|
auto elementType = getType().getElementType();
|
|
|
|
size_t bitWidth =
|
|
|
|
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
|
|
|
APInt rawValueData =
|
|
|
|
readBits(getRawData().data(), valueIndex * bitWidth, bitWidth);
|
|
|
|
|
|
|
|
// Convert the raw value data to an attribute value.
|
|
|
|
switch (getKind()) {
|
|
|
|
case Attribute::Kind::DenseIntElements:
|
|
|
|
return IntegerAttr::get(elementType, rawValueData);
|
|
|
|
case Attribute::Kind::DenseFPElements:
|
|
|
|
return FloatAttr::get(
|
|
|
|
elementType, APFloat(elementType.cast<FloatType>().getFloatSemantics(),
|
|
|
|
rawValueData));
|
|
|
|
default:
|
|
|
|
llvm_unreachable("unexpected element type");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
/// Appends every element of this dense attribute to `values`, wrapped as
/// IntegerAttr (for dense integer elements) or FloatAttr (for dense floating
/// point elements) of the element type.
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
  auto eltType = getType().getElementType();
  Attribute::Kind kind = getKind();

  if (kind == Attribute::Kind::DenseIntElements) {
    // Decode the raw integers, then wrap each one.
    SmallVector<APInt, 8> decodedInts;
    this->cast<DenseIntElementsAttr>().getValues(decodedInts);
    for (auto &bits : decodedInts)
      values.push_back(IntegerAttr::get(eltType, bits));
    return;
  }

  if (kind == Attribute::Kind::DenseFPElements) {
    // Decode the raw floats, then wrap each one.
    SmallVector<APFloat, 8> decodedFloats;
    this->cast<DenseFPElementsAttr>().getValues(decodedFloats);
    for (auto &number : decodedFloats)
      values.push_back(FloatAttr::get(eltType, number));
    return;
  }

  llvm_unreachable("unexpected element type");
}
|
|
|
|
|
|
|
|
/// Returns a view of the packed element bytes held in this attribute's storage.
ArrayRef<char> DenseElementsAttr::getRawData() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->data;
}
|
|
|
|
|
2019-01-18 06:11:05 +08:00
|
|
|
/// Parses the raw integer internal value for each dense element into
|
|
|
|
/// 'values'.
|
|
|
|
void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
|
|
|
|
auto elementType = getType().getElementType();
|
|
|
|
auto elementNum = getType().getNumElements();
|
|
|
|
values.reserve(elementNum);
|
|
|
|
|
2019-01-20 12:54:09 +08:00
|
|
|
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
|
|
|
// with double semantics.
|
2019-01-18 06:11:05 +08:00
|
|
|
size_t bitWidth =
|
|
|
|
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
|
|
|
const auto *rawData = getRawData().data();
|
|
|
|
for (size_t i = 0, e = elementNum; i != e; ++i)
|
|
|
|
values.push_back(readBits(rawData, i * bitWidth, bitWidth));
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
|
|
|
|
/// expected to be a 64-bit aligned storage address.
|
|
|
|
void DenseElementsAttr::writeBits(char *rawData, size_t bitPos, APInt value) {
|
|
|
|
size_t bitWidth = value.getBitWidth();
|
|
|
|
|
|
|
|
// If the bitwidth is 1 we just toggle the specific bit.
|
|
|
|
if (bitWidth == 1) {
|
|
|
|
auto *rawIntData = reinterpret_cast<uint64_t *>(rawData);
|
|
|
|
if (value.isOneValue())
|
|
|
|
APInt::tcSetBit(rawIntData, bitPos);
|
|
|
|
else
|
|
|
|
APInt::tcClearBit(rawIntData, bitPos);
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// If the bit position and width are byte aligned, write the storage directly
|
|
|
|
// to the data.
|
|
|
|
if ((bitWidth % 8) == 0 && (bitPos % 8) == 0) {
|
|
|
|
std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
|
|
|
|
bitWidth / 8, rawData + (bitPos / 8));
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Otherwise, convert the raw data into an APInt and insert the value at the
|
|
|
|
// specified bit position.
|
|
|
|
size_t totalWords = APInt::getNumWords((bitPos % 64) + bitWidth);
|
|
|
|
llvm::MutableArrayRef<uint64_t> rawIntData(
|
|
|
|
reinterpret_cast<uint64_t *>(rawData) + (bitPos / 64), totalWords);
|
|
|
|
APInt tempStorage(totalWords * 64, rawIntData);
|
|
|
|
tempStorage.insertBits(value, bitPos % 64);
|
|
|
|
|
|
|
|
// Copy the value back to the raw data.
|
|
|
|
std::copy_n(tempStorage.getRawData(), rawIntData.size(), rawIntData.data());
|
2018-10-26 06:46:10 +08:00
|
|
|
}
|
|
|
|
|
2019-01-18 06:11:05 +08:00
|
|
|
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
|
|
|
|
/// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address.
|
|
|
|
APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
|
|
|
|
size_t bitWidth) {
|
|
|
|
// Reinterpret the raw data as a uint64_t word array and extract the value
|
|
|
|
// starting at 'bitPos'.
|
|
|
|
APInt result(bitWidth, 0);
|
|
|
|
const uint64_t *intData = reinterpret_cast<const uint64_t *>(rawData);
|
|
|
|
APInt::tcExtract(const_cast<uint64_t *>(result.getRawData()),
|
|
|
|
result.getNumWords(), intData, bitWidth, bitPos);
|
|
|
|
return result;
|
2018-10-26 06:46:10 +08:00
|
|
|
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// DenseIntElementsAttr
|
2018-12-18 21:25:17 +08:00
|
|
|
|
2019-01-18 06:11:05 +08:00
|
|
|
/// Appends each integer element to `values`. For integer elements the raw
/// stored bit patterns already are the values, so no conversion is needed.
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
  this->getRawValues(values);
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// DenseFPElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
2019-01-18 06:11:05 +08:00
|
|
|
/// Appends each floating-point element to `values`, reinterpreting the raw
/// stored bits with the float semantics of the element type.
void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
  // Get the raw APInt element values.
  SmallVector<APInt, 8> intValues;
  getRawValues(intValues);

  // The element type of a dense FP attribute must be a FloatType. Use cast<>
  // so a violation trips an assertion. The previous dyn_cast<> result was
  // dereferenced unconditionally, which would crash on a null type instead.
  auto elementType = getType().getElementType().cast<FloatType>();
  const auto &elementSemantics = elementType.getFloatSemantics();

  // Convert each of the APInt values to an APFloat.
  values.reserve(values.size() + intValues.size());
  for (auto &intValue : intValues)
    values.emplace_back(elementSemantics, intValue);
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// OpaqueElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the opaque byte payload held in this attribute's storage.
StringRef OpaqueElementsAttr::getValue() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->bytes;
}
|
|
|
|
|
2019-02-12 14:51:34 +08:00
|
|
|
/// Returns the dialect recorded for this opaque payload. It may be null;
/// decode() checks for that.
Dialect *OpaqueElementsAttr::getDialect() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->dialect;
}
|
|
|
|
|
|
|
|
/// Asks the owning dialect to decode this opaque payload into `result`.
///
/// With no dialect attached, returns true and leaves `result` untouched.
/// Otherwise it forwards whatever Dialect::decodeHook returns. Presumably true
/// means failure, following the LLVM convention; confirm against decodeHook.
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
  Dialect *owner = getDialect();
  if (!owner)
    return true;
  return owner->decodeHook(*this, result);
}
|
|
|
|
|
2019-01-12 04:33:12 +08:00
|
|
|
/// SparseElementsAttr
|
2018-10-26 06:46:10 +08:00
|
|
|
|
|
|
|
/// Returns the dense integer attribute listing the positions of the explicitly
/// stored elements.
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->indices;
}
|
|
|
|
|
|
|
|
/// Returns the dense attribute holding the explicitly stored element values.
/// Entry i corresponds to the i-th index tuple in getIndices().
DenseElementsAttr SparseElementsAttr::getValues() const {
  auto *storage = static_cast<ImplType *>(attr);
  return storage->values;
}
|
2019-01-20 12:54:09 +08:00
|
|
|
|
|
|
|
/// Return the value of the element at the given index.
|
|
|
|
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|
|
|
auto type = getType();
|
|
|
|
|
|
|
|
// Verify that the rank of the indices matches the held type.
|
|
|
|
auto rank = type.getRank();
|
|
|
|
if (rank != index.size())
|
|
|
|
return Attribute();
|
|
|
|
|
|
|
|
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
|
|
|
|
// as a 1-D index array.
|
|
|
|
auto sparseIndices = getIndices();
|
|
|
|
const uint64_t *sparseIndexValues =
|
|
|
|
reinterpret_cast<const uint64_t *>(sparseIndices.getRawData().data());
|
|
|
|
|
|
|
|
// Build a mapping between known indices and the offset of the stored element.
|
|
|
|
llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
|
2019-01-24 06:39:45 +08:00
|
|
|
auto numSparseIndices = sparseIndices.getType().getDimSize(0);
|
2019-01-20 12:54:09 +08:00
|
|
|
for (size_t i = 0, e = numSparseIndices; i != e; ++i)
|
|
|
|
mappedIndices.try_emplace(
|
|
|
|
{sparseIndexValues + (i * rank), static_cast<size_t>(rank)}, i);
|
|
|
|
|
|
|
|
// Look for the provided index key within the mapped indices. If the provided
|
|
|
|
// index is not found, then return a zero attribute.
|
|
|
|
auto it = mappedIndices.find(index);
|
|
|
|
if (it == mappedIndices.end()) {
|
|
|
|
auto eltType = type.getElementType();
|
|
|
|
if (eltType.isa<FloatType>())
|
|
|
|
return FloatAttr::get(eltType, 0);
|
|
|
|
assert(eltType.isa<IntegerType>() && "unexpected element type");
|
|
|
|
return IntegerAttr::get(eltType, 0);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Otherwise, return the held sparse value element.
|
|
|
|
return getValues().getValue(it->second);
|
|
|
|
}
|