forked from OSchip/llvm-project
[mlir][linalg] Allow TC ops taking an unused shaped operand.
If an operand is not used in the formula, it is considered a shape-only operand, and the results of its indexing map are the reduction dims of the first expression. Depends On D97383 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D97384
This commit is contained in:
parent
4941fef9c4
commit
855a119604
|
@ -582,8 +582,9 @@ better adapt to Linalg:
|
|||
resorting to more general MLIR parsing.
|
||||
1. Reduction dimensions are specified with angle bracket notation on the
|
||||
operation they apply to (e.g. `std_add<k>` specifies that `k` is a reduction
|
||||
dimension). In TC, a reduction is specified with `op=` operator and the
|
||||
reduction dimensions are inferred.
|
||||
dimension). In TC, the reduction dimensions are inferred. If one of the
|
||||
operands is not used in any expression, it will be considered a shape-only
|
||||
operand, and the results of its indexing_map will be the reduction dimensions.
|
||||
1. The parallel and reduction dimensions are ordered by the textual program
|
||||
order. For instance, in the comprehension `O(i, j) = std_add<k, l>(...)`,
|
||||
`i` (resp. `j`) is a parallel iterator encoded by affine dimension of
|
||||
|
|
|
@ -190,3 +190,14 @@ def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M))
|
|||
{
|
||||
C(m) = std_subf<k>(std_mulf(A(m, k), B(k)), C(m));
|
||||
}
|
||||
|
||||
// Test shape-only operand.
|
||||
// IMPL-LABEL: ArrayAttr Test9Op::indexing_maps() {
|
||||
// IMPL: auto map0 = AffineMap::get(2, 2, {d0, d1}, context);
|
||||
// IMPL: auto map1 = AffineMap::get(2, 2, {d1}, context);
|
||||
// IMPL: auto map2 = AffineMap::get(2, 2, {d0}, context);
|
||||
ods_def<Test9Op>:
|
||||
def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M))
|
||||
{
|
||||
C(m) = std_addf<k>(C(m), A(m, k));
|
||||
}
|
||||
|
|
|
@ -1634,7 +1634,26 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
|
|||
tensor.indexingMap = use.indexingMap;
|
||||
state.orderedTensorArgs[use] = tensor.index;
|
||||
});
|
||||
state.numArgs = seenDefs.size();
|
||||
// If fewer tensors are used than registered, the unused ones are shape-only operands, which
|
||||
// are used to define reduction loops. For now, only accept exactly one
|
||||
// shape-only operand.
|
||||
if (state.numArgs > seenDefs.size() + 1) {
|
||||
failed = true;
|
||||
} else if (state.numArgs == seenDefs.size() + 1) {
|
||||
for (auto &tensorIter : registeredTensors) {
|
||||
auto &tensor = tensorIter.getValue();
|
||||
if (tensor.indexingMap)
|
||||
continue;
|
||||
if (auto *pTensorExpr =
|
||||
dyn_cast<TensorExpr>(state.expressions[0].get())) {
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
for (auto dim : pTensorExpr->reductionDimensions)
|
||||
exprs.push_back(getAffineDimExpr(dim, parser.context));
|
||||
tensor.indexingMap = AffineMap::get(state.dims.size(), symbols.size(),
|
||||
exprs, parser.context);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (failed)
|
||||
return failure();
|
||||
|
||||
|
@ -1762,6 +1781,7 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
|
|||
SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
|
||||
while (parser.curToken.isNot(Token::Kind::r_brace)) {
|
||||
perComprehensionStates.push_back(ComprehensionParsingState());
|
||||
perComprehensionStates.back().numArgs = registeredTensors.size();
|
||||
if (failed(parseOneComprehension(cppOpName, tcName,
|
||||
perComprehensionStates.back())))
|
||||
return failure();
|
||||
|
@ -2207,10 +2227,6 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
|
|||
std::string mapsStr;
|
||||
llvm::raw_string_ostream mapsStringStream(mapsStr);
|
||||
|
||||
SmallVector<TensorUse, 4> orderedUses(state.numArgs);
|
||||
for (const auto &it : state.orderedTensorArgs)
|
||||
orderedUses[it.second] = it.first;
|
||||
|
||||
// Create a list of all symbols.
|
||||
SmallVector<std::string, 4> symbolReplacements;
|
||||
symbolReplacements.reserve(symbols.size());
|
||||
|
@ -2242,10 +2258,11 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
|
|||
symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index());
|
||||
}
|
||||
|
||||
// For each tensor use, construct the affine map, replace symbols by the
|
||||
// corresponding attribute values, and simplify the affine map.
|
||||
for (auto tensorUse : llvm::enumerate(orderedUses)) {
|
||||
auto indexingMap = tensorUse.value().indexingMap;
|
||||
// For each registered tensor, construct the affine map, replace symbols by
|
||||
// the corresponding attribute values, and simplify the affine map.
|
||||
for (auto &tensorIter : registeredTensors) {
|
||||
auto &tensor = tensorIter.getValue();
|
||||
auto indexingMap = tensor.indexingMap;
|
||||
const char *mapFmt =
|
||||
"\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);";
|
||||
|
||||
|
@ -2255,8 +2272,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
|
|||
llvm::interleaveComma(indexingMap.getResults(), exprsStringStream);
|
||||
exprsStringStream << "}";
|
||||
exprsStringStream.flush();
|
||||
mapsStringStream << llvm::formatv(mapFmt, tensorUse.index(),
|
||||
state.dims.size(),
|
||||
mapsStringStream << llvm::formatv(mapFmt, tensor.index, state.dims.size(),
|
||||
indexingMap.getNumSymbols(), exprsStr);
|
||||
|
||||
std::string replaceSymbolList =
|
||||
|
@ -2269,17 +2285,17 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
|
|||
// need that.
|
||||
const char *replaceFmt =
|
||||
"\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);";
|
||||
mapsStringStream << llvm::formatv(replaceFmt, tensorUse.index(),
|
||||
mapsStringStream << llvm::formatv(replaceFmt, tensor.index,
|
||||
replaceSymbolList, state.dims.size());
|
||||
const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});";
|
||||
mapsStringStream << llvm::formatv(simplifyFmt, tensorUse.index());
|
||||
mapsStringStream << llvm::formatv(simplifyFmt, tensor.index);
|
||||
}
|
||||
|
||||
mapsStringStream.flush();
|
||||
|
||||
SmallVector<std::string, 4> mapList;
|
||||
mapList.reserve(orderedUses.size());
|
||||
for (unsigned i = 0; i < orderedUses.size(); ++i)
|
||||
mapList.reserve(state.numArgs);
|
||||
for (auto i : llvm::seq<unsigned>(0, state.numArgs))
|
||||
mapList.push_back(llvm::formatv("map{0}", i));
|
||||
|
||||
// 4. Apply format to 1. using 2. and 3.
|
||||
|
|
Loading…
Reference in New Issue