[mlir][linalg] Add support for using scalar attributes in TC ops.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D97876
This commit is contained in:
Hanhan Wang 2021-03-10 01:51:00 -08:00
parent 5293287630
commit d5d4fb635e
2 changed files with 25 additions and 21 deletions

View File

@ -85,12 +85,12 @@ def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
// Test attribute definitions
// ODS-LABEL: def Test4Op
// ODS: F32ArrayAttr:$array_attr,
// ODS: F32:$f32_attr,
// ODS: F32Attr:$f32_attr,
// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr,
// ODS: I32:$i32_attr,
// ODS: I64:$i64_attr,
// ODS: I32Attr:$i32_attr,
// ODS: I64Attr:$i64_attr,
// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
// ODS: OptionalAttr<F32>:$optional_attr
// ODS: OptionalAttr<F32Attr>:$optional_attr
//
// ODS: bool hasDynamicIndexingMaps();
// ODS: LogicalResult verifyIndexingMapRequiredAttributes();

View File

@ -1174,7 +1174,7 @@ private:
// Returns the function to get values at the given indices from this
// attribute.
std::string getValueFn(ArrayRef<uint64_t> indices) const;
llvm::Optional<std::string> getValueFn(ArrayRef<uint64_t> indices) const;
};
//===--------------------------------------------------------------------===//
@ -1841,16 +1841,19 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
const auto &dims = attr.second.vectorDims;
if (!dims.empty()) {
// Vector case
SmallVector<std::string, 4> dimStrs;
for (uint64_t dim : dims)
dimStrs.push_back(std::to_string(dim));
odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
llvm::join(dimStrs, ", "));
}
assert(dims.empty() || !attr.second.isArray);
if (attr.second.isArray)
} else if (attr.second.isArray) {
// Array case
odsType = llvm::formatv("{0}ArrayAttr", odsType);
} else {
// Scalar case
odsType = llvm::formatv("{0}Attr", odsType);
}
if (attr.second.isOptional)
odsType = llvm::formatv("OptionalAttr<{0}>", odsType);
@ -2242,13 +2245,14 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
StringRef attrName = attrUse.value().attrName;
auto it = registeredAttrs.find(attrName.str());
assert(it != registeredAttrs.end() && "uses should point to valid attr!");
std::string getValueFn = it->second.getValueFn(attrUse.value().indices);
if (getValueFn.empty()) {
llvm::Optional<std::string> getValueFn =
it->second.getValueFn(attrUse.value().indices);
if (!getValueFn) {
(void)parser.emitError("unimplemented getValueFn for attribute: " +
attrName);
return;
}
std::string cstVal = llvm::formatv("{0}().{1}", attrName, getValueFn);
std::string cstVal = llvm::formatv("{0}(){1}", attrName, *getValueFn);
const char *cstFmt =
"\n\tauto cst{0} = getAffineConstantExpr({1}, context);";
mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal);
@ -2374,10 +2378,10 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
expressionsStr, yieldStr);
}
std::string
llvm::Optional<std::string>
TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
if (isArray)
return "";
return llvm::None;
if (!vectorDims.empty()) {
SmallVector<std::string, 4> indexStrs;
@ -2385,20 +2389,20 @@ TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
indexStrs.push_back(std::to_string(index));
std::string indexList = llvm::join(indexStrs, ", ");
if (elementType == "f32")
return llvm::formatv("getValue<float>({ {0} })", indexList);
return llvm::formatv(".getValue<float>({ {0} })", indexList).str();
if (elementType == "i32")
return llvm::formatv("getValue<int>({ {0} })", indexList);
return llvm::formatv(".getValue<int>({ {0} })", indexList).str();
if (elementType == "i64")
return llvm::formatv("getValue<int64_t>({ {0} })", indexList);
return llvm::formatv(".getValue<int64_t>({ {0} })", indexList).str();
return "";
return llvm::None;
}
if (elementType == "f32")
return "getValue().convertToFloat()";
return std::string(".convertToFloat()");
if (elementType == "i32" || elementType == "i64")
return "getInt()";
return "";
return std::string("");
return llvm::None;
}
/// Iterate over each Tensor Comprehension def.