[mlir][Linalg] Refactor and improve vectorization to add support for reduction into 0-d tensors.

This revision takes advantage of the recently added support for 0-d transfers and vector.multi_reduction that return a scalar.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D111626
This commit is contained in:
parent bdd37c9f49
commit 753a67b5c9
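For orientation before the diff below, a small, hypothetical C++ usage sketch of the two convenience builders this change introduces. Only vector::TransferReadOp::createScalarOp and vector::TransferWriteOp::createScalarOp come from this commit; the helper name, the source/dest values, and the include paths are illustrative assumptions.

// Hypothetical sketch: round-trip one element of a 0-d shaped value through
// the new vector<1xT>-based scalar path. Include paths are assumed for this
// revision of the tree.
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static void copyScalarElement(OpBuilder &b, Location loc, Value source,
                              Value dest) {
  // Emits a vector<1xT> transfer_read followed by vector.extract (0-d vectors
  // are not available yet) and returns the extracted scalar.
  Value scalar = vector::TransferReadOp::createScalarOp(
      b, loc, source, /*indices=*/ValueRange{});
  // Broadcasts the scalar into a vector<1xT> and emits a transfer_write with
  // a constant-zero permutation map.
  vector::TransferWriteOp::createScalarOp(b, loc, scalar, dest,
                                          /*indices=*/ValueRange{});
}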
@@ -1288,7 +1288,7 @@ def Vector_TransferReadOp :
OpBuilder<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, "AffineMap":$permutationMap,
CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
// Builder that sets padding to 'getMinorIdentityMap'.
// Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, "Value":$padding,
CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
@@ -1306,6 +1306,17 @@ def Vector_TransferReadOp :
"ArrayAttr":$inBounds)>
];

let extraClassDeclaration = [{
/// Temporary convenience builders to account for the fact that we do not
/// have 0-d vectors atm. These create a constant `vector<1xt>` and
/// insert/extract into it.
// Builder that sets permutation map (resp. padding) to
// 'getMinorIdentityMap' (resp. zero).
static Value createScalarOp(OpBuilder &builder, Location loc, Value source,
ValueRange indices,
ArrayRef<bool> inBounds = ArrayRef<bool>{});
}];

let hasCanonicalizer = 1;
let hasFolder = 1;
}
@@ -1416,11 +1427,12 @@ def Vector_TransferWriteOp :
}];

let builders = [
// Builder that sets an empty mask.
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
// Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
@@ -1429,6 +1441,18 @@ def Vector_TransferWriteOp :
"AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
];

let extraClassDeclaration = [{
/// Temporary convenience builders to account for the fact that we do not
/// have 0-d vectors atm. These create a constant `vector<1xt>` and
/// insert/extract into it.
// Builder that sets permutation map (resp. padding) to
// 'getMinorIdentityMap' (resp. zero).
static Operation *createScalarOp(
OpBuilder &builder, Location loc, Value value,
Value dest, ValueRange indices,
ArrayRef<bool> inBounds = ArrayRef<bool>{});
}];

let hasFolder = 1;
let hasCanonicalizer = 1;
}

@@ -40,6 +40,9 @@ using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
@@ -106,7 +109,7 @@ struct VectorizationResult {
/// ShapedType of `v`.
static VectorType extractVectorTypeFromShapedValue(Value v) {
auto st = v.getType().cast<ShapedType>();
if (st.isa<MemRefType>() && st.getShape().empty())
if (st.getShape().empty())
return VectorType();
return VectorType::get(st.getShape(), st.getElementType());
}
@@ -163,16 +166,23 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}

/// If value of assumed VectorType has a shape different than `shape`, build and
/// return a new vector.broadcast to `shape`.
/// Otherwise, just return value.
static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
Value value, OpOperand *outputOperand) {
/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
/// whether a reduction is needed to produce a `targetType` and create that
/// reduction if it is the case.
static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
OpOperand *outputOperand) {
LDBG("Reduce " << value << " to type " << targetType);
LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
<< *(outputOperand->getOwner()));
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
auto vecType = value.getType().dyn_cast<VectorType>();
if (!vecType || vecType.getShape() == targetVectorType.getShape())
VectorType targetVectorType = targetType.dyn_cast<VectorType>();
if (!vecType)
return value;
if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
return value;

// At this point, we know we need to reduce. Detect the reduction operator.
unsigned pos = 0;
MLIRContext *ctx = b.getContext();
SmallVector<AffineExpr> exprs;
@@ -181,7 +191,6 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
exprs.push_back(getAffineDimExpr(pos++, ctx));
auto loc = value.getLoc();

// At this point, we know we need to reduce. Detect the reduction operator.
auto maybeKind = matchLinalgReduction(outputOperand);
assert(maybeKind && "Failed precondition: could not get reduction kind");
unsigned idx = 0;
@@ -196,16 +205,18 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
}

/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If source has rank zero, build an memref.load.
/// If source has rank zero, build a `vector<1xt> transfer_read + extract`.
/// Return the produced value.
static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
AffineMap map) {
Location loc = source.getLoc();
auto shapedType = source.getType().cast<ShapedType>();
SmallVector<Value> indices(shapedType.getRank(),
b.create<ConstantIndexOp>(loc, 0));
return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
map);
if (auto vectorType = readType.dyn_cast<VectorType>())
return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
map);
return vector::TransferReadOp::createScalarOp(b, loc, source, indices);
}

/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -216,13 +227,14 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
OpOperand *outputOperand) {
Operation *write;
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
if (VectorType vectorType =
extractVectorTypeFromShapedValue(outputOperand->get())) {
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
AffineMap map =
reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
SmallVector<int64_t> transposeShape =
applyPermutationMap(inversePermutation(map), vectorType.getShape());
assert(!transposeShape.empty() && "unexpected empty transpose shape");
vectorType = VectorType::get(transposeShape, vectorType.getElementType());
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
b.create<ConstantIndexOp>(loc, 0));
@@ -231,9 +243,12 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
indices, map);
} else {
write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
value =
reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand);
write = vector::TransferWriteOp::createScalarOp(
b, loc, value, outputOperand->get(), ValueRange{});
}
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
LDBG("vectorized op: " << *write);
if (!write->getResults().empty())
return write->getResult(0);
return Value();
@@ -329,7 +344,7 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
static VectorizationResult
vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
LDBG("vectorize op " << *op);

// 1. Try to apply any CustomVectorizationHook.
if (!customVectorizationHooks.empty()) {
@@ -466,33 +481,27 @@ LogicalResult vectorizeAsLinalgGeneric(
continue;
}
// TODO: 0-d vectors.
if (linalgOp.getShape(opOperand).empty()) {
Value loaded =
b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << loaded);
bvm.map(bbarg, loaded);
bvm.map(opOperand->get(), loaded);
continue;
}
Type readType;
AffineMap map;
VectorType vectorType;
if (broadcastToMaximalCommonShape) {
map = inverseAndBroadcastProjectedPermuation(
linalgOp.getTiedIndexingMap(opOperand));
vectorType = VectorType::get(commonVectorShape,
getElementTypeOrSelf(opOperand->get()));
if (linalgOp.getShape(opOperand).empty()) {
readType = bbarg.getType();
} else {
map = inversePermutation(
reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
if (broadcastToMaximalCommonShape) {
map = inverseAndBroadcastProjectedPermuation(
linalgOp.getTiedIndexingMap(opOperand));
readType = VectorType::get(commonVectorShape,
getElementTypeOrSelf(opOperand->get()));
} else {
map = inversePermutation(
reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
getElementTypeOrSelf(opOperand->get()));
}
}
Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << vectorRead);
bvm.map(bbarg, vectorRead);
bvm.map(opOperand->get(), vectorRead);
Value readValue = buildVectorRead(b, opOperand->get(), readType, map);
LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
bvm.map(bbarg, readValue);
bvm.map(opOperand->get(), readValue);
}

auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -516,12 +525,11 @@ LogicalResult vectorizeAsLinalgGeneric(
for (Operation &op : block.getOperations()) {
VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
LDBG("failed to vectorize: " << op);
return failure();
}
if (result.status == VectorizationStatus::NewOp) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
<< *result.newOp;);
LDBG("new vector op: " << *result.newOp;);
bvm.map(op.getResults(), result.newOp->getResults());
}
}
@@ -536,9 +544,9 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
Location loc = linalgOp.getLoc();
// Vectorize other ops as vector contraction.
// TODO: interface.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg op as vector.contract: ";
linalgOp.dump());
LDBG(""
<< "Rewrite linalg op as vector.contract: ";
linalgOp.dump());
// Special function that describes how to vectorize the multiplication op in a
// linalg contraction.
CustomVectorizationHook vectorizeContraction =
@@ -592,11 +600,15 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
if (llvm::none_of(op.iterator_types(), isReductionIterator))
if (llvm::none_of(op.iterator_types(), isReductionIterator)) {
LDBG("reduction precondition failed: no reduction iterator");
return failure();
}
for (OpOperand *opOperand : op.getOutputOperands()) {
if (!matchLinalgReduction(opOperand))
if (!matchLinalgReduction(opOperand)) {
LDBG("reduction precondition failed: reduction detection failed");
return failure();
}
}
return success();
}
@@ -604,8 +616,10 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
if (linalgOp.hasDynamicShape())
if (linalgOp.hasDynamicShape()) {
LDBG("precondition failed: dynamic shape");
return failure();
}
if (isElementwise(op))
return success();
if (isaContractionOpInterface(linalgOp))
@@ -613,10 +627,15 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
// TODO: the common vector shape is equal to the static loop sizes only when
// all indexing maps are projected permutations. For convs and stencils the
// logic will need to evolve.
if (allIndexingsAreProjectedPermutation(linalgOp) &&
succeeded(reductionPreconditions(linalgOp)))
return success();
return failure();
if (!allIndexingsAreProjectedPermutation(linalgOp)) {
LDBG("precondition failed: not projected permutations");
return failure();
}
if (failed(reductionPreconditions(linalgOp))) {
LDBG("precondition failed: reduction preconditions");
return failure();
}
return success();
}

LogicalResult
@@ -629,10 +648,10 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
if (isaContractionOpInterface(linalgOp))
return vectorizeContraction(b, linalgOp, newResults);

LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Vectorize linalg op as a generic by broadcasting to "
"maximal common shape: "
<< *op);
LDBG(""
<< "Vectorize linalg op as a generic by broadcasting to "
"maximal common shape: "
<< *op);
return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
/*broadcastToMaximalCommonShape=*/true);
}
@@ -1200,9 +1219,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
ValueRange values) {
if (firstOp->getBlock() != secondOp->getBlock() ||
!firstOp->isBeforeInBlock(secondOp)) {
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "interleavedUses precondition failed, firstOp: "
<< *firstOp << ", second op: " << *secondOp);
LDBG("interleavedUses precondition failed, firstOp: "
<< *firstOp << ", second op: " << *secondOp);
return true;
}
for (auto v : values) {
@@ -1214,10 +1232,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
if (owner->getBlock() == firstOp->getBlock() &&
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
continue;
LLVM_DEBUG(llvm::dbgs()
<< "\n[" DEBUG_TYPE "]: "
<< " found interleaved op " << *owner
<< ", firstOp: " << *firstOp << ", second op: " << *secondOp);
LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
<< ", second op: " << *secondOp);
return true;
}
}
@@ -1248,15 +1264,14 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return failure();

LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
LDBG(viewOrAlloc);

// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return failure();
Value subView = subViewOp.getResult();
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "with subView " << subView);
LDBG("with subView " << subView);

// Find the copy into `subView` without interleaved uses.
CopyOp copyOp;
@@ -1265,8 +1280,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
assert(newCopyOp.output().getType().isa<MemRefType>());
if (newCopyOp.output() != subView)
continue;
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "copy candidate " << *newCopyOp);
LDBG("copy candidate " << *newCopyOp);
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
continue;
copyOp = newCopyOp;
@@ -1275,8 +1289,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
}
if (!copyOp)
return failure();
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "with copy " << *copyOp);
LDBG("with copy " << *copyOp);

// Find the fill into `viewOrAlloc` without interleaved uses before the copy.
FillOp maybeFillOp;
@@ -1285,8 +1298,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
assert(newFillOp.output().getType().isa<MemRefType>());
if (newFillOp.output() != viewOrAlloc)
continue;
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "fill candidate " << *newFillOp);
LDBG("fill candidate " << *newFillOp);
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
continue;
maybeFillOp = newFillOp;
@@ -1297,8 +1309,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
return failure();
if (maybeFillOp)
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "with maybeFillOp " << *maybeFillOp);
LDBG("with maybeFillOp " << *maybeFillOp);

// `in` is the subview that linalg.copy reads. Replace it.
Value in = copyOp.input();

@@ -2439,6 +2439,18 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
/*mask=*/Value(), inBounds);
}

Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc,
Value source, ValueRange indices,
ArrayRef<bool> inBounds) {
Type elemType = source.getType().cast<ShapedType>().getElementType();
auto vectorType = VectorType::get(ArrayRef<int64_t>{1}, elemType);
AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, loc.getContext()));
Value read = builder.create<vector::TransferReadOp>(loc, vectorType, source,
indices, map, inBounds);
return builder.create<vector::ExtractOp>(loc, read, ArrayRef<int64_t>{0});
}

static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 3> elidedAttrs;
elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
@@ -2769,6 +2781,16 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
// TransferWriteOp
//===----------------------------------------------------------------------===//

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
AffineMap permutationMap, ArrayRef<bool> inBounds) {
if (inBounds.empty())
return build(builder, result, vector, dest, indices, permutationMap,
/*mask=*/Value(), ArrayAttr());
build(builder, result, vector, dest, indices, permutationMap,
/*mask=*/Value(), builder.getBoolArrayAttr(inBounds));
}

/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
@@ -2783,13 +2805,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
AffineMap permutationMap) {
build(builder, result, vector, source, indices, permutationMap,
/*inBounds=*/ArrayAttr());
}

void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
AffineMapAttr permutationMap,
@@ -2817,6 +2832,20 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
mask, inBounds);
}

Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc,
Value value, Value dest,
ValueRange indices,
ArrayRef<bool> inBounds) {
Value vectorOfAScalar = value;
if (!value.getType().isa<VectorType>())
vectorOfAScalar = builder.create<vector::BroadcastOp>(
loc, VectorType::get({1}, value.getType()), value);
AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, loc.getContext()));
return builder.create<vector::TransferWriteOp>(loc, vectorOfAScalar, dest,
indices, map, inBounds);
}

static ParseResult parseTransferWriteOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();

@@ -203,8 +203,9 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {

// CHECK-LABEL: func @test_vectorize_fill
func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
// CHECK: store %[[V]], %[[M]][] : memref<f32>
// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
// CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
// CHECK: vector.transfer_write %[[VEC]], %[[M]][] {{.*}} : vector<1xf32>, memref<f32>
linalg.fill(%arg0, %A) : f32, memref<f32>
return
}
@@ -223,8 +224,11 @@ func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {

// CHECK-LABEL: func @test_vectorize_copy_scalar
func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
// CHECK: %[[V:.*]] = memref.load {{.*}} : memref<f32>
// CHECK: store %[[V]], {{.*}} : memref<f32>
// CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
// CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<1xf32>
// CHECK: %[[val:.*]] = vector.extract %[[V]][0] : vector<1xf32>
// CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
// CHECK: vector.transfer_write %[[VV]], %[[B]][] {{.*}} : vector<1xf32>, memref<f32>
linalg.copy(%A, %B) : memref<f32>, memref<f32>
return
}
@@ -857,3 +861,42 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
return %red : tensor<4xf32>
}

// -----

// CHECK-LABEL: func @reduce_1d(
// CHECK-SAME: %[[A:.*]]: tensor<32xf32>
func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK-DAG: %[[F0_v1:.*]] = constant dense<0.000000e+00> : vector<1xf32>
// CHECK-DAG: %[[F0_v32:.*]] = constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
%f0 = constant 0.000000e+00 : f32

// CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
%0 = linalg.init_tensor [] : tensor<f32>

// CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][]
// CHECK-SAME: : vector<1xf32>, tensor<f32>
%1 = linalg.fill(%f0, %0) : f32, tensor<f32> -> tensor<f32>

// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
// CHECK: %[[a:.*]] = addf %[[r]], %[[F0_v32]] : vector<32xf32>
// CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[a]] [0]
// CHECK-SAME: : vector<32xf32> to f32
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<1xf32>
// CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
// CHECK-SAME: : vector<1xf32>, tensor<f32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> ()>],
iterator_types = ["reduction"]}
ins(%arg0 : tensor<32xf32>)
outs(%1 : tensor<f32>) {
^bb0(%a: f32, %b: f32): // no predecessors
%3 = addf %a, %b : f32
linalg.yield %3 : f32
} -> tensor<f32>

return %2 : tensor<f32>
}