forked from OSchip/llvm-project
[mlir] Change IteratorType in ContractionOp in Vector dialect from string to enum.
This is the first step in replacing interator_type from strings with enums in Vector and Linalg dialect. This change adds IteratorTypeAttr and uses it in ContractionOp. To avoid breaking all the tests, print/parse code has conversion between string and enum for now. There is a shared code in StructuredOpsUtils.h that expects iterator types to be strings. To break this dependancy, this change forks helper function `isParallelIterator` and `isReductionIterator` to utils in both dialects and adds `getIteratorTypeNames()` to support backward compatibility with StructuredGenerator. In the later changes, I plan to add a similar enum attribute to Linalg. Differential Revision: https://reviews.llvm.org/D133696
This commit is contained in:
parent
7ecd4d2b7c
commit
4758e916e1
|
@ -727,6 +727,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
// TODO: Remove once prefixing is flipped.
|
||||
ArrayAttr getIteratorTypes() { return iterator_types(); }
|
||||
|
||||
SmallVector<StringRef> getIteratorTypeNames() {
|
||||
return llvm::to_vector(getIteratorTypes().getAsValueRange<StringAttr>());
|
||||
}
|
||||
|
||||
//========================================================================//
|
||||
// Forwarding functions to access interface methods from the
|
||||
// DestinationStyleOpInterface.
|
||||
|
|
|
@ -45,6 +45,12 @@ bool isElementwise(LinalgOp op);
|
|||
/// `[0, permutation.size())`.
|
||||
bool isPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
/// Check if `attr` has "parallel" iterator type semantics.
|
||||
bool isParallelIterator(Attribute attr);
|
||||
|
||||
/// Check if `attr` has "reduction" iterator type semantics.
|
||||
bool isReductionIterator(Attribute attr);
|
||||
|
||||
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
|
||||
/// the type of `source`.
|
||||
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
|
||||
|
|
|
@ -78,24 +78,12 @@ constexpr StringRef getPaddingAttrName() { return "padding"; }
|
|||
|
||||
/// Use to encode that a particular iterator type has parallel semantics.
|
||||
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
|
||||
inline bool isParallelIterator(Attribute attr) {
|
||||
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
|
||||
return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
|
||||
}
|
||||
|
||||
/// Use to encode that a particular iterator type has reduction semantics.
|
||||
constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
|
||||
inline bool isReductionIterator(Attribute attr) {
|
||||
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
|
||||
return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
|
||||
}
|
||||
|
||||
/// Use to encode that a particular iterator type has window semantics.
|
||||
constexpr StringRef getWindowIteratorTypeName() { return "window"; }
|
||||
inline bool isWindowIterator(Attribute attr) {
|
||||
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
|
||||
return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
|
||||
}
|
||||
|
||||
/// Use to encode that a particular iterator type has window semantics.
|
||||
inline ArrayRef<StringRef> getAllIteratorTypeNames() {
|
||||
|
@ -122,19 +110,6 @@ inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
|
|||
return res;
|
||||
}
|
||||
|
||||
/// Typed representation for loop type strings.
|
||||
enum class IteratorType { Parallel, Reduction };
|
||||
|
||||
inline StringRef toString(IteratorType t) {
|
||||
switch (t) {
|
||||
case IteratorType::Parallel:
|
||||
return getParallelIteratorTypeName();
|
||||
case IteratorType::Reduction:
|
||||
return getReductionIteratorTypeName();
|
||||
}
|
||||
llvm_unreachable("Unsupported IteratorType");
|
||||
}
|
||||
|
||||
/// Helper StructuredGenerator class to manipulate and rewrite ops with
|
||||
/// `StructuredOpInterface`. This is templated for now because VectorOps do not
|
||||
/// yet implement the StructuredOpInterface itself.
|
||||
|
@ -145,10 +120,7 @@ public:
|
|||
|
||||
struct IteratorType {
|
||||
IteratorType(StringRef strRef) : strRef(strRef) {}
|
||||
bool isOfType(Attribute attr) const {
|
||||
auto sAttr = attr.dyn_cast<StringAttr>();
|
||||
return sAttr && sAttr.getValue() == strRef;
|
||||
}
|
||||
bool isOfType(StringRef typeName) const { return typeName == strRef; }
|
||||
StringRef strRef;
|
||||
};
|
||||
struct Par : public IteratorType {
|
||||
|
@ -163,7 +135,7 @@ public:
|
|||
|
||||
StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
|
||||
: builder(builder), ctx(op.getContext()), loc(op.getLoc()),
|
||||
iterators(op.getIteratorTypes()), maps(op.getIndexingMapsArray()),
|
||||
iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
|
||||
op(op) {}
|
||||
|
||||
bool iters(ArrayRef<IteratorType> its) {
|
||||
|
@ -185,7 +157,7 @@ protected:
|
|||
OpBuilder &builder;
|
||||
MLIRContext *ctx;
|
||||
Location loc;
|
||||
ArrayAttr iterators;
|
||||
SmallVector<StringRef> iterators;
|
||||
SmallVector<AffineMap, 4> maps;
|
||||
Operation *op;
|
||||
};
|
||||
|
|
|
@ -185,6 +185,17 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
|
|||
/// corresponding arith operation.
|
||||
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
|
||||
Value v1, Value v2);
|
||||
|
||||
/// Returns true if `attr` has "parallel" iterator type semantics.
|
||||
inline bool isParallelIterator(Attribute attr) {
|
||||
return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::parallel;
|
||||
}
|
||||
|
||||
/// Returns true if `attr` has "reduction" iterator type semantics.
|
||||
inline bool isReductionIterator(Attribute attr) {
|
||||
return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
|
||||
}
|
||||
|
||||
} // namespace vector
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -63,6 +63,21 @@ def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
|
|||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
|
||||
I32EnumAttrCase<"parallel", 0>,
|
||||
I32EnumAttrCase<"reduction", 1>
|
||||
]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::vector";
|
||||
}
|
||||
|
||||
def IteratorTypeEnum : EnumAttr<Vector_Dialect, IteratorType, "iterator_type"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
|
||||
"Iterator type should be an enum.">;
|
||||
|
||||
// TODO: Add an attribute to specify a different algebra with operators other
|
||||
// than the current set: {*, +}.
|
||||
def Vector_ContractionOp :
|
||||
|
@ -76,7 +91,7 @@ def Vector_ContractionOp :
|
|||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
|
||||
Variadic<VectorOf<[I1]>>:$masks,
|
||||
ArrayAttr:$indexing_maps,
|
||||
ArrayAttr:$iterator_types,
|
||||
IteratorTypeArrayAttr:$iterator_types,
|
||||
DefaultValuedAttr<Vector_CombiningKindAttr,
|
||||
"CombiningKind::ADD">:$kind)>,
|
||||
Results<(outs AnyType)> {
|
||||
|
@ -201,7 +216,7 @@ def Vector_ContractionOp :
|
|||
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes)>,
|
||||
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
|
||||
"ArrayRef<ArrayRef<AffineExpr>>":$indexingExprs,
|
||||
"ArrayRef<StringRef>":$iteratorTypes)>,
|
||||
"ArrayRef<IteratorType>":$iteratorTypes)>,
|
||||
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc,
|
||||
"ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes,
|
||||
"CombiningKind":$kind)>
|
||||
|
@ -249,6 +264,14 @@ def Vector_ContractionOp :
|
|||
static CombiningKind getDefaultKind() {
|
||||
return CombiningKind::ADD;
|
||||
}
|
||||
|
||||
// Returns iterator types in string format.
|
||||
SmallVector<StringRef> getIteratorTypeNames() {
|
||||
return llvm::to_vector(
|
||||
llvm::map_range(getIteratorTypes(), [](Attribute a) {
|
||||
return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue());
|
||||
}));
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
|
|
@ -217,10 +217,10 @@ FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
|
|||
params.targetLayout = NVVM::MMALayout::col;
|
||||
}
|
||||
ArrayRef<int64_t> shape = type.vectorType.getShape();
|
||||
params.contiguousDimType =
|
||||
transpose ? IteratorType::Parallel : IteratorType::Reduction;
|
||||
params.contiguousDimType = transpose ? vector::IteratorType::parallel
|
||||
: vector::IteratorType::reduction;
|
||||
|
||||
if (params.contiguousDimType == IteratorType::Reduction) {
|
||||
if (params.contiguousDimType == vector::IteratorType::reduction) {
|
||||
params.numTiles = (shape[0] / kNumRowsPerTile) *
|
||||
((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
|
||||
} else {
|
||||
|
@ -250,7 +250,7 @@ getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
|
|||
};
|
||||
|
||||
// This case corresponds to row-major A|C or col-major B operands.
|
||||
if (params.contiguousDimType == IteratorType::Reduction) {
|
||||
if (params.contiguousDimType == vector::IteratorType::reduction) {
|
||||
AffineExpr row = d0 % (operandShape[0]);
|
||||
AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
|
||||
return makeMap({row, col});
|
||||
|
@ -258,7 +258,7 @@ getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
|
|||
|
||||
// This case Corresponds to col-major A|C or row-major B operands. The
|
||||
// operandShape given is already pre-transposed (e.g. 8x16 = KxN).
|
||||
if (params.contiguousDimType == IteratorType::Parallel) {
|
||||
if (params.contiguousDimType == vector::IteratorType::parallel) {
|
||||
const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
|
||||
// Threads are assigned in groups of 8 first across columns, then to
|
||||
// rows. This is transpose of what `ldmatrix` expects, but when
|
||||
|
@ -293,9 +293,9 @@ PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
|
|||
SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
|
||||
if (iteratorTypes.size() != 3)
|
||||
return failure();
|
||||
if (!(isParallelIterator(iteratorTypes[0]) &&
|
||||
isParallelIterator(iteratorTypes[1]) &&
|
||||
isReductionIterator(iteratorTypes[2])))
|
||||
if (!(vector::isParallelIterator(iteratorTypes[0]) &&
|
||||
vector::isParallelIterator(iteratorTypes[1]) &&
|
||||
vector::isReductionIterator(iteratorTypes[2])))
|
||||
return failure();
|
||||
|
||||
// The canonical form is "TNT" = A row-major, B col-major, C row-major.
|
||||
|
|
|
@ -71,7 +71,7 @@ struct LdMatrixParams {
|
|||
VectorType fragmentType;
|
||||
bool isAccum;
|
||||
int64_t numTiles;
|
||||
IteratorType contiguousDimType;
|
||||
vector::IteratorType contiguousDimType;
|
||||
NVVM::MMALayout targetLayout;
|
||||
};
|
||||
|
||||
|
|
|
@ -74,9 +74,9 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
|
|||
AffineExpr m, n, k;
|
||||
bindDims(contract.getContext(), m, n, k);
|
||||
auto iteratorTypes = contract.getIteratorTypes().getValue();
|
||||
if (!(isParallelIterator(iteratorTypes[0]) &&
|
||||
isParallelIterator(iteratorTypes[1]) &&
|
||||
isReductionIterator(iteratorTypes[2])))
|
||||
if (!(vector::isParallelIterator(iteratorTypes[0]) &&
|
||||
vector::isParallelIterator(iteratorTypes[1]) &&
|
||||
vector::isReductionIterator(iteratorTypes[2])))
|
||||
return false;
|
||||
|
||||
// The contract needs to represent a matmul to be able to convert to
|
||||
|
@ -296,9 +296,9 @@ struct PrepareContractToGPUMMA
|
|||
static constexpr std::array<int64_t, 2> perm = {1, 0};
|
||||
auto iteratorTypes = op.getIteratorTypes().getValue();
|
||||
SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
|
||||
if (!(isParallelIterator(iteratorTypes[0]) &&
|
||||
isParallelIterator(iteratorTypes[1]) &&
|
||||
isReductionIterator(iteratorTypes[2])))
|
||||
if (!(vector::isParallelIterator(iteratorTypes[0]) &&
|
||||
vector::isParallelIterator(iteratorTypes[1]) &&
|
||||
vector::isReductionIterator(iteratorTypes[2])))
|
||||
return failure();
|
||||
//
|
||||
// Two outer parallel, one inner reduction (matmat flavor).
|
||||
|
|
|
@ -1488,13 +1488,14 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
|||
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
|
||||
Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs,
|
||||
Value rhs, Value res) {
|
||||
StringRef par = Par().strRef, red = Red().strRef;
|
||||
vector::IteratorType par = vector::IteratorType::parallel;
|
||||
vector::IteratorType red = vector::IteratorType::reduction;
|
||||
AffineExpr n, w, f, c;
|
||||
bindDims(ctx, n, w, f, c);
|
||||
return builder.create<vector::ContractionOp>(
|
||||
loc, lhs, rhs, res,
|
||||
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
|
||||
/*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
|
||||
/*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
|
||||
}
|
||||
|
||||
/// Generate a vector implementation for:
|
||||
|
|
|
@ -199,6 +199,16 @@ bool isPermutation(ArrayRef<int64_t> permutation) {
|
|||
return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
|
||||
}
|
||||
|
||||
bool isParallelIterator(Attribute attr) {
|
||||
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
|
||||
return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
|
||||
}
|
||||
|
||||
bool isReductionIterator(Attribute attr) {
|
||||
auto strAttr = attr.dyn_cast_or_null<StringAttr>();
|
||||
return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
|
||||
}
|
||||
|
||||
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
|
||||
/// the type of `source`.
|
||||
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
|
||||
|
|
|
@ -350,7 +350,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
|
|||
if (isMaterializing(lhs->get())) {
|
||||
unsigned nest = 0;
|
||||
for (unsigned i = 0; i < numLoops; i++) {
|
||||
if (isReductionIterator(iteratorTypes[topSort[i]]))
|
||||
if (linalg::isReductionIterator(iteratorTypes[topSort[i]]))
|
||||
break; // terminate at first reduction
|
||||
nest++;
|
||||
}
|
||||
|
@ -1234,7 +1234,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
|
|||
unsigned tensor = merger.tensor(fb);
|
||||
assert(idx == merger.index(fb));
|
||||
auto iteratorTypes = op.iterator_types().getValue();
|
||||
bool isReduction = isReductionIterator(iteratorTypes[idx]);
|
||||
bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
|
||||
bool isSparse = merger.isDim(fb, Dim::kSparse);
|
||||
bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
|
||||
denseUnitStrides(merger, op, idx);
|
||||
|
|
|
@ -455,14 +455,18 @@ void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value lhs, Value rhs, Value acc,
|
||||
ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
|
||||
ArrayRef<StringRef> iteratorTypes) {
|
||||
ArrayRef<IteratorType> iteratorTypes) {
|
||||
result.addOperands({lhs, rhs, acc});
|
||||
result.addTypes(acc.getType());
|
||||
result.addAttribute(::mlir::getIndexingMapsAttrName(),
|
||||
builder.getAffineMapArrayAttr(
|
||||
AffineMap::inferFromExprList(indexingExprs)));
|
||||
result.addAttribute(::mlir::getIteratorTypesAttrName(),
|
||||
builder.getStrArrayAttr(iteratorTypes));
|
||||
result.addAttribute(
|
||||
::mlir::getIteratorTypesAttrName(),
|
||||
builder.getArrayAttr(llvm::to_vector(llvm::map_range(
|
||||
iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
|
||||
return IteratorTypeAttr::get(builder.getContext(), t);
|
||||
}))));
|
||||
}
|
||||
|
||||
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
|
||||
|
@ -510,6 +514,27 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
return failure();
|
||||
result.attributes.assign(dictAttr.getValue().begin(),
|
||||
dictAttr.getValue().end());
|
||||
|
||||
// Convert array of string into an array of IteratyType enums. This is needed,
|
||||
// because tests still use the old format when 'iterator_types' attribute is
|
||||
// represented as an array of strings.
|
||||
// TODO: Remove this conversion once tests are fixed.
|
||||
ArrayAttr iteratorTypes =
|
||||
result.attributes.get("iterator_types").cast<ArrayAttr>();
|
||||
|
||||
SmallVector<Attribute> iteratorTypeAttrs;
|
||||
|
||||
for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
|
||||
auto maybeIteratorType = symbolizeIteratorType(s);
|
||||
if (!maybeIteratorType.hasValue())
|
||||
return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
|
||||
|
||||
iteratorTypeAttrs.push_back(IteratorTypeAttr::get(
|
||||
parser.getContext(), maybeIteratorType.getValue()));
|
||||
}
|
||||
result.attributes.set("iterator_types",
|
||||
parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
|
||||
|
||||
if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
|
||||
result.addAttribute(
|
||||
ContractionOp::getKindAttrStrName(),
|
||||
|
@ -538,9 +563,26 @@ void ContractionOp::print(OpAsmPrinter &p) {
|
|||
llvm::StringSet<> traitAttrsSet;
|
||||
traitAttrsSet.insert(attrNames.begin(), attrNames.end());
|
||||
SmallVector<NamedAttribute, 8> attrs;
|
||||
for (auto attr : (*this)->getAttrs())
|
||||
if (traitAttrsSet.count(attr.getName().strref()) > 0)
|
||||
for (auto attr : (*this)->getAttrs()) {
|
||||
if (attr.getName() == getIteratorTypesAttrName()) {
|
||||
auto iteratorTypes =
|
||||
attr.getValue()
|
||||
.cast<ArrayAttr>()
|
||||
.getAsValueRange<IteratorTypeAttr, IteratorType>();
|
||||
// Convert IteratorType enums into the string representation. This is
|
||||
// needed, because tests still use the old format when 'iterator_types'
|
||||
// attribute is represented as an array of strings.
|
||||
// TODO: Remove this conversion once tests are fixed.
|
||||
SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
|
||||
llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
|
||||
return StringAttr::get(getContext(), stringifyIteratorType(t));
|
||||
}));
|
||||
|
||||
attrs.emplace_back(getIteratorTypesAttrName(),
|
||||
ArrayAttr::get(getContext(), iteratorTypeNames));
|
||||
} else if (traitAttrsSet.count(attr.getName().strref()) > 0)
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
|
||||
auto dictAttr = DictionaryAttr::get(getContext(), attrs);
|
||||
p << " " << dictAttr << " " << getLhs() << ", ";
|
||||
|
@ -746,11 +788,11 @@ static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
|
|||
|
||||
static std::vector<std::pair<int64_t, int64_t>>
|
||||
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
|
||||
StringRef targetIteratorTypeName, MLIRContext *context) {
|
||||
IteratorType targetIteratorType, MLIRContext *context) {
|
||||
std::vector<std::pair<int64_t, int64_t>> dimMap;
|
||||
for (const auto &it : llvm::enumerate(iteratorTypes)) {
|
||||
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
|
||||
if (iteratorTypeName != targetIteratorTypeName)
|
||||
auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
|
||||
if (iteratorType != targetIteratorType)
|
||||
continue;
|
||||
// Search lhs/rhs map results for 'targetExpr'.
|
||||
auto targetExpr = getAffineDimExpr(it.index(), context);
|
||||
|
@ -771,8 +813,8 @@ void ContractionOp::getIterationBounds(
|
|||
for (const auto &it : llvm::enumerate(getIteratorTypes())) {
|
||||
// Search lhs/rhs map results for 'targetExpr'.
|
||||
auto targetExpr = getAffineDimExpr(it.index(), getContext());
|
||||
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
|
||||
if (iteratorTypeName == getReductionIteratorTypeName()) {
|
||||
auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
|
||||
if (iteratorType == IteratorType::reduction) {
|
||||
// Get reduction dim size from lhs shape (same size in rhsShape).
|
||||
int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
|
||||
assert(lhsDimIndex >= 0);
|
||||
|
@ -803,14 +845,14 @@ void ContractionOp::getIterationIndexMap(
|
|||
|
||||
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
|
||||
SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
|
||||
return getDimMap(indexingMaps, getIteratorTypes(),
|
||||
getReductionIteratorTypeName(), getContext());
|
||||
return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
|
||||
getContext());
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
|
||||
SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
|
||||
return getDimMap(indexingMaps, getIteratorTypes(),
|
||||
getParallelIteratorTypeName(), getContext());
|
||||
return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
|
||||
getContext());
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
|
@ -986,13 +987,13 @@ struct MultiReduceToContract
|
|||
SmallVector<bool> reductionMask = reduceOp.getReductionMask();
|
||||
auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
|
||||
SmallVector<AffineExpr> exprs;
|
||||
SmallVector<StringRef> iteratorTypes;
|
||||
SmallVector<vector::IteratorType> iteratorTypes;
|
||||
for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
|
||||
if (!isReduceDim.value()) {
|
||||
iteratorTypes.push_back(getParallelIteratorTypeName());
|
||||
iteratorTypes.push_back(vector::IteratorType::parallel);
|
||||
exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
|
||||
} else {
|
||||
iteratorTypes.push_back(getReductionIteratorTypeName());
|
||||
iteratorTypes.push_back(vector::IteratorType::reduction);
|
||||
}
|
||||
}
|
||||
auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
|
||||
|
@ -1000,7 +1001,10 @@ struct MultiReduceToContract
|
|||
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
|
||||
reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
|
||||
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
|
||||
rewriter.getStrArrayAttr(iteratorTypes));
|
||||
rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
|
||||
iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
|
||||
return IteratorTypeAttr::get(rewriter.getContext(), t);
|
||||
}))));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue