forked from OSchip/llvm-project
[mlir][linalg] Add support for using scalar attributes in TC ops.
Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D97876
This commit is contained in:
parent
5293287630
commit
d5d4fb635e
|
@ -85,12 +85,12 @@ def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
|
|||
// Test attribute definitions
|
||||
// ODS-LABEL: def Test4Op
|
||||
// ODS: F32ArrayAttr:$array_attr,
|
||||
// ODS: F32:$f32_attr,
|
||||
// ODS: F32Attr:$f32_attr,
|
||||
// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr,
|
||||
// ODS: I32:$i32_attr,
|
||||
// ODS: I64:$i64_attr,
|
||||
// ODS: I32Attr:$i32_attr,
|
||||
// ODS: I64Attr:$i64_attr,
|
||||
// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
|
||||
// ODS: OptionalAttr<F32>:$optional_attr
|
||||
// ODS: OptionalAttr<F32Attr>:$optional_attr
|
||||
//
|
||||
// ODS: bool hasDynamicIndexingMaps();
|
||||
// ODS: LogicalResult verifyIndexingMapRequiredAttributes();
|
||||
|
|
|
@ -1174,7 +1174,7 @@ private:
|
|||
|
||||
// Returns the function to get values at the given indices from this
|
||||
// attribute.
|
||||
std::string getValueFn(ArrayRef<uint64_t> indices) const;
|
||||
llvm::Optional<std::string> getValueFn(ArrayRef<uint64_t> indices) const;
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -1841,16 +1841,19 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
|||
|
||||
const auto &dims = attr.second.vectorDims;
|
||||
if (!dims.empty()) {
|
||||
// Vector case
|
||||
SmallVector<std::string, 4> dimStrs;
|
||||
for (uint64_t dim : dims)
|
||||
dimStrs.push_back(std::to_string(dim));
|
||||
odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
|
||||
llvm::join(dimStrs, ", "));
|
||||
}
|
||||
|
||||
assert(dims.empty() || !attr.second.isArray);
|
||||
if (attr.second.isArray)
|
||||
} else if (attr.second.isArray) {
|
||||
// Array case
|
||||
odsType = llvm::formatv("{0}ArrayAttr", odsType);
|
||||
} else {
|
||||
// Scalar case
|
||||
odsType = llvm::formatv("{0}Attr", odsType);
|
||||
}
|
||||
|
||||
if (attr.second.isOptional)
|
||||
odsType = llvm::formatv("OptionalAttr<{0}>", odsType);
|
||||
|
@ -2242,13 +2245,14 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
|
|||
StringRef attrName = attrUse.value().attrName;
|
||||
auto it = registeredAttrs.find(attrName.str());
|
||||
assert(it != registeredAttrs.end() && "uses should point to valid attr!");
|
||||
std::string getValueFn = it->second.getValueFn(attrUse.value().indices);
|
||||
if (getValueFn.empty()) {
|
||||
llvm::Optional<std::string> getValueFn =
|
||||
it->second.getValueFn(attrUse.value().indices);
|
||||
if (!getValueFn) {
|
||||
(void)parser.emitError("unimplemented getValueFn for attribute: " +
|
||||
attrName);
|
||||
return;
|
||||
}
|
||||
std::string cstVal = llvm::formatv("{0}().{1}", attrName, getValueFn);
|
||||
std::string cstVal = llvm::formatv("{0}(){1}", attrName, *getValueFn);
|
||||
const char *cstFmt =
|
||||
"\n\tauto cst{0} = getAffineConstantExpr({1}, context);";
|
||||
mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal);
|
||||
|
@ -2374,10 +2378,10 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
|
|||
expressionsStr, yieldStr);
|
||||
}
|
||||
|
||||
std::string
|
||||
llvm::Optional<std::string>
|
||||
TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
|
||||
if (isArray)
|
||||
return "";
|
||||
return llvm::None;
|
||||
|
||||
if (!vectorDims.empty()) {
|
||||
SmallVector<std::string, 4> indexStrs;
|
||||
|
@ -2385,20 +2389,20 @@ TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
|
|||
indexStrs.push_back(std::to_string(index));
|
||||
std::string indexList = llvm::join(indexStrs, ", ");
|
||||
if (elementType == "f32")
|
||||
return llvm::formatv("getValue<float>({ {0} })", indexList);
|
||||
return llvm::formatv(".getValue<float>({ {0} })", indexList).str();
|
||||
if (elementType == "i32")
|
||||
return llvm::formatv("getValue<int>({ {0} })", indexList);
|
||||
return llvm::formatv(".getValue<int>({ {0} })", indexList).str();
|
||||
if (elementType == "i64")
|
||||
return llvm::formatv("getValue<int64_t>({ {0} })", indexList);
|
||||
return llvm::formatv(".getValue<int64_t>({ {0} })", indexList).str();
|
||||
|
||||
return "";
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
if (elementType == "f32")
|
||||
return "getValue().convertToFloat()";
|
||||
return std::string(".convertToFloat()");
|
||||
if (elementType == "i32" || elementType == "i64")
|
||||
return "getInt()";
|
||||
return "";
|
||||
return std::string("");
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// Iterate over each Tensor Comprehension def.
|
||||
|
|
Loading…
Reference in New Issue