forked from OSchip/llvm-project
Add support to constant splat vector/tensor attribute.
This attribute represents a reference to a splat vector or tensor, where all the elements have the same value. The syntax of the attribute is: `splat<` (tensor-type | vector-type)`,` attribute-value `>` PiperOrigin-RevId: 216537997
This commit is contained in:
parent
fd06c6bc4e
commit
5e3cca906a
|
@ -27,6 +27,7 @@ class Function;
|
|||
class FunctionType;
|
||||
class MLIRContext;
|
||||
class Type;
|
||||
class VectorOrTensorType;
|
||||
|
||||
/// Attributes are known-constant values of operations and functions.
|
||||
///
|
||||
|
@ -43,6 +44,10 @@ public:
|
|||
Array,
|
||||
AffineMap,
|
||||
Function,
|
||||
|
||||
SplatElements,
|
||||
FIRST_ELEMENTS_ATTR = SplatElements,
|
||||
LAST_ELEMENTS_ATTR = SplatElements,
|
||||
};
|
||||
|
||||
/// Return the classification for this attribute.
|
||||
|
@ -249,6 +254,41 @@ private:
|
|||
Function *value;
|
||||
};
|
||||
|
||||
/// A base attribute represents a reference to a vector or tensor constant.
|
||||
class ElementsAttr : public Attribute {
|
||||
public:
|
||||
ElementsAttr(Kind kind, VectorOrTensorType *type)
|
||||
: Attribute(kind, /*isOrContainsFunction=*/false), type(type) {}
|
||||
|
||||
VectorOrTensorType *getType() const { return type; }
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() >= Kind::FIRST_ELEMENTS_ATTR &&
|
||||
attr->getKind() <= Kind::LAST_ELEMENTS_ATTR;
|
||||
}
|
||||
|
||||
private:
|
||||
VectorOrTensorType *type;
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a splat vecctor or tensor constant,
|
||||
/// meaning all of the elements have the same value.
|
||||
class SplatElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
static ElementsAttr *get(VectorOrTensorType *type, Attribute *elt);
|
||||
Attribute *getValue() const { return elt; }
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::SplatElements;
|
||||
}
|
||||
|
||||
private:
|
||||
SplatElementsAttr(VectorOrTensorType *type, Attribute *elt)
|
||||
: ElementsAttr(Kind::SplatElements, type), elt(elt) {}
|
||||
Attribute *elt;
|
||||
};
|
||||
} // end namespace mlir.
|
||||
|
||||
#endif
|
||||
|
|
|
@ -43,6 +43,7 @@ class StringAttr;
|
|||
class TypeAttr;
|
||||
class ArrayAttr;
|
||||
class FunctionAttr;
|
||||
class ElementsAttr;
|
||||
class AffineMapAttr;
|
||||
class AffineMap;
|
||||
|
||||
|
@ -98,6 +99,7 @@ public:
|
|||
AffineMapAttr *getAffineMapAttr(AffineMap map);
|
||||
TypeAttr *getTypeAttr(Type *type);
|
||||
FunctionAttr *getFunctionAttr(const Function *value);
|
||||
ElementsAttr *getSplatElementsAttr(VectorOrTensorType *type, Attribute *elt);
|
||||
|
||||
// Affine expressions and affine maps.
|
||||
AffineExpr getAffineDimExpr(unsigned position);
|
||||
|
|
|
@ -303,10 +303,7 @@ public:
|
|||
|
||||
/// If any dimension has unknown size (<0), it doesn't have static shape.
|
||||
/// If all dimensions has known size (>= 0), it has static shape.
|
||||
bool hasStaticShape() const {
|
||||
auto dims = getShape();
|
||||
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
|
||||
}
|
||||
bool hasStaticShape() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the size of the specified
|
||||
/// dimension. It aborts if the tensor is unranked (this can be checked by
|
||||
|
|
|
@ -456,6 +456,15 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case Attribute::Kind::SplatElements: {
|
||||
auto *elementsAttr = cast<SplatElementsAttr>(attr);
|
||||
os << "splat<";
|
||||
printType(elementsAttr->getType());
|
||||
os << ", ";
|
||||
printAttribute(elementsAttr->getValue());
|
||||
os << '>';
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -144,6 +144,11 @@ FunctionAttr *Builder::getFunctionAttr(const Function *value) {
|
|||
return FunctionAttr::get(value, context);
|
||||
}
|
||||
|
||||
ElementsAttr *Builder::getSplatElementsAttr(VectorOrTensorType *type,
|
||||
Attribute *elt) {
|
||||
return SplatElementsAttr::get(type, elt);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Affine Expressions, Affine Maps, and Integet Sets.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -274,6 +274,8 @@ public:
|
|||
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
|
||||
AttributeListSet attributeLists;
|
||||
DenseMap<const Function *, FunctionAttr *> functionAttrs;
|
||||
DenseMap<std::pair<VectorOrTensorType *, Attribute *>, SplatElementsAttr *>
|
||||
splatElementsAttrs;
|
||||
|
||||
public:
|
||||
MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {
|
||||
|
@ -775,7 +777,6 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
|||
}
|
||||
}
|
||||
|
||||
// Ok, now that we've canonicalized our attributes, unique them.
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Look to see if we already have this.
|
||||
|
@ -797,6 +798,23 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
|||
return *existing.first = result;
|
||||
}
|
||||
|
||||
ElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type, Attribute *elt) {
|
||||
auto &impl = type->getContext()->getImpl();
|
||||
|
||||
// Look to see if we already have this.
|
||||
auto *&result = impl.splatElementsAttrs[{type, elt}];
|
||||
|
||||
// If we already have it, return that value.
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
// Otherwise, allocate them into the bump pointer.
|
||||
result = impl.allocator.Allocate<SplatElementsAttr>();
|
||||
new (result) SplatElementsAttr(type, elt);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineMap and AffineExpr uniquing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -47,9 +47,8 @@ int VectorOrTensorType::getRank() const {
|
|||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType");
|
||||
case Kind::Vector:
|
||||
return cast<VectorType>(this)->getShape().size();
|
||||
case Kind::RankedTensor:
|
||||
return cast<RankedTensorType>(this)->getShape().size();
|
||||
return getShape().size();
|
||||
case Kind::UnrankedTensor:
|
||||
return -1;
|
||||
}
|
||||
|
@ -58,14 +57,31 @@ int VectorOrTensorType::getRank() const {
|
|||
int VectorOrTensorType::getDimSize(unsigned i) const {
|
||||
switch (getKind()) {
|
||||
case Kind::Vector:
|
||||
return cast<VectorType>(this)->getShape()[i];
|
||||
case Kind::RankedTensor:
|
||||
return cast<RankedTensorType>(this)->getShape()[i];
|
||||
return getShape()[i];
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType");
|
||||
}
|
||||
}
|
||||
|
||||
ArrayRef<int> VectorOrTensorType::getShape() const {
|
||||
switch (getKind()) {
|
||||
case Kind::Vector:
|
||||
return cast<VectorType>(this)->getShape();
|
||||
case Kind::RankedTensor:
|
||||
return cast<RankedTensorType>(this)->getShape();
|
||||
case Kind::UnrankedTensor:
|
||||
return cast<RankedTensorType>(this)->getShape();
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType");
|
||||
}
|
||||
}
|
||||
|
||||
bool VectorOrTensorType::hasStaticShape() const {
|
||||
auto dims = getShape();
|
||||
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
|
||||
}
|
||||
|
||||
VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context)
|
||||
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
|
||||
|
|
|
@ -658,6 +658,8 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
|||
/// | type
|
||||
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
|
||||
/// | function-id `:` function-type
|
||||
/// | `splat<` (tensor-type | vector-type)`,`
|
||||
/// attribute-value `>`
|
||||
///
|
||||
Attribute *Parser::parseAttribute() {
|
||||
switch (getToken().getKind()) {
|
||||
|
@ -752,6 +754,42 @@ Attribute *Parser::parseAttribute() {
|
|||
return function ? builder.getFunctionAttr(function) : nullptr;
|
||||
}
|
||||
|
||||
case Token::kw_splat: {
|
||||
consumeToken(Token::kw_splat);
|
||||
if (parseToken(Token::less, "Expected '<' after 'elements'"))
|
||||
return nullptr;
|
||||
|
||||
auto *type = dyn_cast<VectorOrTensorType>(parseType());
|
||||
if (!type) {
|
||||
return (
|
||||
emitError("expected elements literal has a tensor or vector type"),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
if (parseToken(Token::comma, "Expected ','"))
|
||||
return nullptr;
|
||||
|
||||
if (!type->hasStaticShape() || type->getRank() == -1) {
|
||||
return (emitError("tensor literals must be ranked and have static shape"),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
switch (getToken().getKind()) {
|
||||
case Token::floatliteral:
|
||||
case Token::integer:
|
||||
case Token::minus: {
|
||||
auto *scalar = parseAttribute();
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
return builder.getSplatElementsAttr(type, scalar);
|
||||
}
|
||||
default:
|
||||
return (
|
||||
emitError("expected '[' or scalar constant inside tensor literal"),
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
default: {
|
||||
if (Type *type = parseType())
|
||||
return builder.getTypeAttr(type);
|
||||
|
|
|
@ -94,6 +94,7 @@ TOK_KEYWORD(ceildiv)
|
|||
TOK_KEYWORD(cfgfunc)
|
||||
TOK_KEYWORD(cond_br)
|
||||
TOK_KEYWORD(else)
|
||||
TOK_KEYWORD(splat)
|
||||
TOK_KEYWORD(extfunc)
|
||||
TOK_KEYWORD(f16)
|
||||
TOK_KEYWORD(f32)
|
||||
|
|
|
@ -484,3 +484,17 @@ mlfunc @mlfuncsimplemap(%arg0 : index, %arg1 : index) -> () {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cfgfunc @tensorattr
|
||||
cfgfunc @tensorattr() -> () {
|
||||
bb0:
|
||||
// CHECK: "splatIntTensor"() {bar: splat<tensor<2x1x4xi32>, 5>} : () -> ()
|
||||
"splatIntTensor"(){bar: splat<tensor<2x1x4xi32>, 5>} : () -> ()
|
||||
// CHECK: "splatFloatTensor"() {bar: splat<tensor<2x1x4xf32>, -5.000000e+00>} : () -> ()
|
||||
"splatFloatTensor"(){bar: splat<tensor<2x1x4xf32>, -5.0>} : () -> ()
|
||||
// CHECK: "splatIntVector"() {bar: splat<vector<2x1x4xi64>, 5>} : () -> ()
|
||||
"splatIntVector"(){bar: splat<vector<2x1x4xi64>, 5>} : () -> ()
|
||||
// CHECK: "splatFloatVector"() {bar: splat<vector<2x1x4xf16>, -5.000000e+00>} : () -> ()
|
||||
"splatFloatVector"(){bar: splat<vector<2x1x4xf16>, -5.0>} : () -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue