diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc index 7627a017a07e..8c4c31da2fce 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -85,12 +85,12 @@ def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { // Test attribute definitions // ODS-LABEL: def Test4Op // ODS: F32ArrayAttr:$array_attr, -// ODS: F32:$f32_attr, +// ODS: F32Attr:$f32_attr, // ODS: RankedF32ElementsAttr<[4]>:$fvec_attr, -// ODS: I32:$i32_attr, -// ODS: I64:$i64_attr, +// ODS: I32Attr:$i32_attr, +// ODS: I64Attr:$i64_attr, // ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr, -// ODS: OptionalAttr:$optional_attr +// ODS: OptionalAttr:$optional_attr // // ODS: bool hasDynamicIndexingMaps(); // ODS: LogicalResult verifyIndexingMapRequiredAttributes(); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index 9d2d26a5cbd2..7165a0fe89fe 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1174,7 +1174,7 @@ private: // Returns the function to get values at the given indices from this // attribute. - std::string getValueFn(ArrayRef indices) const; + llvm::Optional getValueFn(ArrayRef indices) const; }; //===--------------------------------------------------------------------===// @@ -1841,16 +1841,19 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, const auto &dims = attr.second.vectorDims; if (!dims.empty()) { + // Vector case SmallVector dimStrs; for (uint64_t dim : dims) dimStrs.push_back(std::to_string(dim)); odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType, llvm::join(dimStrs, ", ")); - } - - assert(dims.empty() || !attr.second.isArray); - if (attr.second.isArray) + } else if (attr.second.isArray) { + // Array case odsType = llvm::formatv("{0}ArrayAttr", odsType); + } else { + // Scalar case + odsType = llvm::formatv("{0}Attr", odsType); + } if (attr.second.isOptional) odsType = llvm::formatv("OptionalAttr<{0}>", odsType); @@ -2242,13 +2245,14 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef attrName = attrUse.value().attrName; auto it = registeredAttrs.find(attrName.str()); assert(it != registeredAttrs.end() && "uses should point to valid attr!"); - std::string getValueFn = it->second.getValueFn(attrUse.value().indices); - if (getValueFn.empty()) { + llvm::Optional getValueFn = + it->second.getValueFn(attrUse.value().indices); + if (!getValueFn) { (void)parser.emitError("unimplemented getValueFn for attribute: " + attrName); return; } - std::string cstVal = llvm::formatv("{0}().{1}", attrName, getValueFn); + std::string cstVal = llvm::formatv("{0}(){1}", attrName, *getValueFn); const char *cstFmt = "\n\tauto cst{0} = getAffineConstantExpr({1}, context);"; mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal); @@ -2374,10 +2378,10 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, expressionsStr, yieldStr); } -std::string +llvm::Optional TCParser::RegisteredAttr::getValueFn(ArrayRef indices) const { if (isArray) - return ""; + return llvm::None; if (!vectorDims.empty()) { SmallVector indexStrs; @@ -2385,20 +2389,20 @@ TCParser::RegisteredAttr::getValueFn(ArrayRef indices) const { indexStrs.push_back(std::to_string(index)); std::string indexList = llvm::join(indexStrs, ", "); if (elementType == "f32") - return llvm::formatv("getValue({ {0} })", indexList); + return llvm::formatv(".getValue({ {0} })", indexList).str(); if (elementType == "i32") - return llvm::formatv("getValue({ {0} })", indexList); + return llvm::formatv(".getValue({ {0} })", indexList).str(); if (elementType == "i64") - return llvm::formatv("getValue({ {0} })", indexList); + return llvm::formatv(".getValue({ {0} })", indexList).str(); - return ""; + return llvm::None; } if (elementType == "f32") - return "getValue().convertToFloat()"; + return std::string(".convertToFloat()"); if (elementType == "i32" || elementType == "i64") - return "getInt()"; - return ""; + return std::string(""); + return llvm::None; } /// Iterate over each Tensor Comprehension def.