[mlir] Change IteratorType in ContractionOp in Vector dialect from string to enum.

This is the first step in replacing interator_type from strings with enums in Vector and Linalg dialect. This change adds IteratorTypeAttr and uses it in ContractionOp.

To avoid breaking all the tests, print/parse code has conversion between string and enum for now.

There is a shared code in StructuredOpsUtils.h that expects iterator types to be strings. To break this dependancy, this change forks helper function `isParallelIterator` and `isReductionIterator` to utils in both dialects and adds `getIteratorTypeNames()` to support backward compatibility with StructuredGenerator.

In the later changes, I plan to add a similar enum attribute to Linalg.

Differential Revision: https://reviews.llvm.org/D133696
This commit is contained in:
Oleg Shyshkov 2022-09-12 16:53:36 +02:00
parent 7ecd4d2b7c
commit 4758e916e1
13 changed files with 143 additions and 70 deletions

View File

@ -727,6 +727,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
// TODO: Remove once prefixing is flipped.
ArrayAttr getIteratorTypes() { return iterator_types(); }
SmallVector<StringRef> getIteratorTypeNames() {
return llvm::to_vector(getIteratorTypes().getAsValueRange<StringAttr>());
}
//========================================================================//
// Forwarding functions to access interface methods from the
// DestinationStyleOpInterface.

View File

@ -45,6 +45,12 @@ bool isElementwise(LinalgOp op);
/// `[0, permutation.size())`.
bool isPermutation(ArrayRef<int64_t> permutation);
/// Check if `attr` has "parallel" iterator type semantics.
bool isParallelIterator(Attribute attr);
/// Check if `attr` has "reduction" iterator type semantics.
bool isReductionIterator(Attribute attr);
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

View File

@ -78,24 +78,12 @@ constexpr StringRef getPaddingAttrName() { return "padding"; }
/// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
inline bool isParallelIterator(Attribute attr) {
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
}
/// Use to encode that a particular iterator type has reduction semantics.
constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
inline bool isReductionIterator(Attribute attr) {
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
}
/// Use to encode that a particular iterator type has window semantics.
constexpr StringRef getWindowIteratorTypeName() { return "window"; }
inline bool isWindowIterator(Attribute attr) {
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
}
/// Use to encode that a particular iterator type has window semantics.
inline ArrayRef<StringRef> getAllIteratorTypeNames() {
@ -122,19 +110,6 @@ inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
return res;
}
/// Typed representation for loop type strings.
enum class IteratorType { Parallel, Reduction };
inline StringRef toString(IteratorType t) {
switch (t) {
case IteratorType::Parallel:
return getParallelIteratorTypeName();
case IteratorType::Reduction:
return getReductionIteratorTypeName();
}
llvm_unreachable("Unsupported IteratorType");
}
/// Helper StructuredGenerator class to manipulate and rewrite ops with
/// `StructuredOpInterface`. This is templated for now because VectorOps do not
/// yet implement the StructuredOpInterface itself.
@ -145,10 +120,7 @@ public:
struct IteratorType {
IteratorType(StringRef strRef) : strRef(strRef) {}
bool isOfType(Attribute attr) const {
auto sAttr = attr.dyn_cast<StringAttr>();
return sAttr && sAttr.getValue() == strRef;
}
bool isOfType(StringRef typeName) const { return typeName == strRef; }
StringRef strRef;
};
struct Par : public IteratorType {
@ -163,7 +135,7 @@ public:
StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
: builder(builder), ctx(op.getContext()), loc(op.getLoc()),
iterators(op.getIteratorTypes()), maps(op.getIndexingMapsArray()),
iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
op(op) {}
bool iters(ArrayRef<IteratorType> its) {
@ -185,7 +157,7 @@ protected:
OpBuilder &builder;
MLIRContext *ctx;
Location loc;
ArrayAttr iterators;
SmallVector<StringRef> iterators;
SmallVector<AffineMap, 4> maps;
Operation *op;
};

View File

@ -185,6 +185,17 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
Value v1, Value v2);
/// Returns true if `attr` has "parallel" iterator type semantics.
inline bool isParallelIterator(Attribute attr) {
return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::parallel;
}
/// Returns true if `attr` has "reduction" iterator type semantics.
inline bool isReductionIterator(Attribute attr) {
return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
}
} // namespace vector
} // namespace mlir

View File

@ -63,6 +63,21 @@ def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
let assemblyFormat = "`<` $value `>`";
}
def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
I32EnumAttrCase<"parallel", 0>,
I32EnumAttrCase<"reduction", 1>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::vector";
}
def IteratorTypeEnum : EnumAttr<Vector_Dialect, IteratorType, "iterator_type"> {
let assemblyFormat = "`<` $value `>`";
}
def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
"Iterator type should be an enum.">;
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
@ -76,7 +91,7 @@ def Vector_ContractionOp :
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
ArrayAttr:$indexing_maps,
ArrayAttr:$iterator_types,
IteratorTypeArrayAttr:$iterator_types,
DefaultValuedAttr<Vector_CombiningKindAttr,
"CombiningKind::ADD">:$kind)>,
Results<(outs AnyType)> {
@ -201,7 +216,7 @@ def Vector_ContractionOp :
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes)>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
"ArrayRef<ArrayRef<AffineExpr>>":$indexingExprs,
"ArrayRef<StringRef>":$iteratorTypes)>,
"ArrayRef<IteratorType>":$iteratorTypes)>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes,
"CombiningKind":$kind)>
@ -249,6 +264,14 @@ def Vector_ContractionOp :
static CombiningKind getDefaultKind() {
return CombiningKind::ADD;
}
// Returns iterator types in string format.
SmallVector<StringRef> getIteratorTypeNames() {
return llvm::to_vector(
llvm::map_range(getIteratorTypes(), [](Attribute a) {
return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue());
}));
}
}];
let hasCanonicalizer = 1;

View File

@ -217,10 +217,10 @@ FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
params.targetLayout = NVVM::MMALayout::col;
}
ArrayRef<int64_t> shape = type.vectorType.getShape();
params.contiguousDimType =
transpose ? IteratorType::Parallel : IteratorType::Reduction;
params.contiguousDimType = transpose ? vector::IteratorType::parallel
: vector::IteratorType::reduction;
if (params.contiguousDimType == IteratorType::Reduction) {
if (params.contiguousDimType == vector::IteratorType::reduction) {
params.numTiles = (shape[0] / kNumRowsPerTile) *
((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
} else {
@ -250,7 +250,7 @@ getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
};
// This case corresponds to row-major A|C or col-major B operands.
if (params.contiguousDimType == IteratorType::Reduction) {
if (params.contiguousDimType == vector::IteratorType::reduction) {
AffineExpr row = d0 % (operandShape[0]);
AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
return makeMap({row, col});
@ -258,7 +258,7 @@ getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
// This case Corresponds to col-major A|C or row-major B operands. The
// operandShape given is already pre-transposed (e.g. 8x16 = KxN).
if (params.contiguousDimType == IteratorType::Parallel) {
if (params.contiguousDimType == vector::IteratorType::parallel) {
const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
// Threads are assigned in groups of 8 first across columns, then to
// rows. This is transpose of what `ldmatrix` expects, but when
@ -293,9 +293,9 @@ PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
if (iteratorTypes.size() != 3)
return failure();
if (!(isParallelIterator(iteratorTypes[0]) &&
isParallelIterator(iteratorTypes[1]) &&
isReductionIterator(iteratorTypes[2])))
if (!(vector::isParallelIterator(iteratorTypes[0]) &&
vector::isParallelIterator(iteratorTypes[1]) &&
vector::isReductionIterator(iteratorTypes[2])))
return failure();
// The canonical form is "TNT" = A row-major, B col-major, C row-major.

View File

@ -71,7 +71,7 @@ struct LdMatrixParams {
VectorType fragmentType;
bool isAccum;
int64_t numTiles;
IteratorType contiguousDimType;
vector::IteratorType contiguousDimType;
NVVM::MMALayout targetLayout;
};

View File

@ -74,9 +74,9 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
AffineExpr m, n, k;
bindDims(contract.getContext(), m, n, k);
auto iteratorTypes = contract.getIteratorTypes().getValue();
if (!(isParallelIterator(iteratorTypes[0]) &&
isParallelIterator(iteratorTypes[1]) &&
isReductionIterator(iteratorTypes[2])))
if (!(vector::isParallelIterator(iteratorTypes[0]) &&
vector::isParallelIterator(iteratorTypes[1]) &&
vector::isReductionIterator(iteratorTypes[2])))
return false;
// The contract needs to represent a matmul to be able to convert to
@ -296,9 +296,9 @@ struct PrepareContractToGPUMMA
static constexpr std::array<int64_t, 2> perm = {1, 0};
auto iteratorTypes = op.getIteratorTypes().getValue();
SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
if (!(isParallelIterator(iteratorTypes[0]) &&
isParallelIterator(iteratorTypes[1]) &&
isReductionIterator(iteratorTypes[2])))
if (!(vector::isParallelIterator(iteratorTypes[0]) &&
vector::isParallelIterator(iteratorTypes[1]) &&
vector::isReductionIterator(iteratorTypes[2])))
return failure();
//
// Two outer parallel, one inner reduction (matmat flavor).

View File

@ -1488,13 +1488,14 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs,
Value rhs, Value res) {
StringRef par = Par().strRef, red = Red().strRef;
vector::IteratorType par = vector::IteratorType::parallel;
vector::IteratorType red = vector::IteratorType::reduction;
AffineExpr n, w, f, c;
bindDims(ctx, n, w, f, c);
return builder.create<vector::ContractionOp>(
loc, lhs, rhs, res,
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
/*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
/*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
}
/// Generate a vector implementation for:

View File

@ -199,6 +199,16 @@ bool isPermutation(ArrayRef<int64_t> permutation) {
return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
}
bool isParallelIterator(Attribute attr) {
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
}
bool isReductionIterator(Attribute attr) {
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
}
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {

View File

@ -350,7 +350,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
if (isMaterializing(lhs->get())) {
unsigned nest = 0;
for (unsigned i = 0; i < numLoops; i++) {
if (isReductionIterator(iteratorTypes[topSort[i]]))
if (linalg::isReductionIterator(iteratorTypes[topSort[i]]))
break; // terminate at first reduction
nest++;
}
@ -1234,7 +1234,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
unsigned tensor = merger.tensor(fb);
assert(idx == merger.index(fb));
auto iteratorTypes = op.iterator_types().getValue();
bool isReduction = isReductionIterator(iteratorTypes[idx]);
bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
bool isSparse = merger.isDim(fb, Dim::kSparse);
bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
denseUnitStrides(merger, op, idx);

View File

@ -455,14 +455,18 @@ void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
Value lhs, Value rhs, Value acc,
ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
ArrayRef<StringRef> iteratorTypes) {
ArrayRef<IteratorType> iteratorTypes) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
result.addAttribute(::mlir::getIndexingMapsAttrName(),
builder.getAffineMapArrayAttr(
AffineMap::inferFromExprList(indexingExprs)));
result.addAttribute(::mlir::getIteratorTypesAttrName(),
builder.getStrArrayAttr(iteratorTypes));
result.addAttribute(
::mlir::getIteratorTypesAttrName(),
builder.getArrayAttr(llvm::to_vector(llvm::map_range(
iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
return IteratorTypeAttr::get(builder.getContext(), t);
}))));
}
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
@ -510,6 +514,27 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
// Convert array of string into an array of IteratyType enums. This is needed,
// because tests still use the old format when 'iterator_types' attribute is
// represented as an array of strings.
// TODO: Remove this conversion once tests are fixed.
ArrayAttr iteratorTypes =
result.attributes.get("iterator_types").cast<ArrayAttr>();
SmallVector<Attribute> iteratorTypeAttrs;
for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
auto maybeIteratorType = symbolizeIteratorType(s);
if (!maybeIteratorType.hasValue())
return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
iteratorTypeAttrs.push_back(IteratorTypeAttr::get(
parser.getContext(), maybeIteratorType.getValue()));
}
result.attributes.set("iterator_types",
parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
result.addAttribute(
ContractionOp::getKindAttrStrName(),
@ -538,9 +563,26 @@ void ContractionOp::print(OpAsmPrinter &p) {
llvm::StringSet<> traitAttrsSet;
traitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs;
for (auto attr : (*this)->getAttrs())
if (traitAttrsSet.count(attr.getName().strref()) > 0)
for (auto attr : (*this)->getAttrs()) {
if (attr.getName() == getIteratorTypesAttrName()) {
auto iteratorTypes =
attr.getValue()
.cast<ArrayAttr>()
.getAsValueRange<IteratorTypeAttr, IteratorType>();
// Convert IteratorType enums into the string representation. This is
// needed, because tests still use the old format when 'iterator_types'
// attribute is represented as an array of strings.
// TODO: Remove this conversion once tests are fixed.
SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
return StringAttr::get(getContext(), stringifyIteratorType(t));
}));
attrs.emplace_back(getIteratorTypesAttrName(),
ArrayAttr::get(getContext(), iteratorTypeNames));
} else if (traitAttrsSet.count(attr.getName().strref()) > 0)
attrs.push_back(attr);
}
auto dictAttr = DictionaryAttr::get(getContext(), attrs);
p << " " << dictAttr << " " << getLhs() << ", ";
@ -746,11 +788,11 @@ static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
StringRef targetIteratorTypeName, MLIRContext *context) {
IteratorType targetIteratorType, MLIRContext *context) {
std::vector<std::pair<int64_t, int64_t>> dimMap;
for (const auto &it : llvm::enumerate(iteratorTypes)) {
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
if (iteratorTypeName != targetIteratorTypeName)
auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
if (iteratorType != targetIteratorType)
continue;
// Search lhs/rhs map results for 'targetExpr'.
auto targetExpr = getAffineDimExpr(it.index(), context);
@ -771,8 +813,8 @@ void ContractionOp::getIterationBounds(
for (const auto &it : llvm::enumerate(getIteratorTypes())) {
// Search lhs/rhs map results for 'targetExpr'.
auto targetExpr = getAffineDimExpr(it.index(), getContext());
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
if (iteratorTypeName == getReductionIteratorTypeName()) {
auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
if (iteratorType == IteratorType::reduction) {
// Get reduction dim size from lhs shape (same size in rhsShape).
int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
assert(lhsDimIndex >= 0);
@ -803,14 +845,14 @@ void ContractionOp::getIterationIndexMap(
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
return getDimMap(indexingMaps, getIteratorTypes(),
getReductionIteratorTypeName(), getContext());
return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
getContext());
}
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
return getDimMap(indexingMaps, getIteratorTypes(),
getParallelIteratorTypeName(), getContext());
return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
getContext());
}
Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {

View File

@ -22,6 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
@ -986,13 +987,13 @@ struct MultiReduceToContract
SmallVector<bool> reductionMask = reduceOp.getReductionMask();
auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
SmallVector<vector::IteratorType> iteratorTypes;
for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
if (!isReduceDim.value()) {
iteratorTypes.push_back(getParallelIteratorTypeName());
iteratorTypes.push_back(vector::IteratorType::parallel);
exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
} else {
iteratorTypes.push_back(getReductionIteratorTypeName());
iteratorTypes.push_back(vector::IteratorType::reduction);
}
}
auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
@ -1000,7 +1001,10 @@ struct MultiReduceToContract
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
rewriter.getStrArrayAttr(iteratorTypes));
rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
return IteratorTypeAttr::get(rewriter.getContext(), t);
}))));
return success();
}
};