[mlir][Linalg] Refactor and improve vectorization to add support for reduction into 0-d tensors.

This revision takes advantage of the recently added support for 0-d transfers and vector.multi_reduction that return a scalar.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D111626
This commit is contained in:
Nicolas Vasilache 2021-10-12 12:39:26 +00:00
parent bdd37c9f49
commit 753a67b5c9
4 changed files with 195 additions and 88 deletions

View File

@ -1288,7 +1288,7 @@ def Vector_TransferReadOp :
OpBuilder<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, "AffineMap":$permutationMap,
CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
// Builder that sets padding to 'getMinorIdentityMap'.
// Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, "Value":$padding,
CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
@ -1306,6 +1306,17 @@ def Vector_TransferReadOp :
"ArrayAttr":$inBounds)>
];
let extraClassDeclaration = [{
/// Temporary convenience builders to account for the fact that we do not
/// have 0-d vectors atm. These create a constant `vector<1xt>` and
/// insert/extract into it.
// Builder that sets permutation map (resp. padding) to
// 'getMinorIdentityMap' (resp. zero).
static Value createScalarOp(OpBuilder &builder, Location loc, Value source,
ValueRange indices,
ArrayRef<bool> inBounds = ArrayRef<bool>{});
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
}
@ -1416,11 +1427,12 @@ def Vector_TransferWriteOp :
}];
let builders = [
// Builder that sets an empty mask.
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
// Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
@ -1429,6 +1441,18 @@ def Vector_TransferWriteOp :
"AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
];
let extraClassDeclaration = [{
/// Temporary convenience builders to account for the fact that we do not
/// have 0-d vectors atm. These create a constant `vector<1xt>` and
/// insert/extract into it.
// Builder that sets permutation map (resp. padding) to
// 'getMinorIdentityMap' (resp. zero).
static Operation *createScalarOp(
OpBuilder &builder, Location loc, Value value,
Value dest, ValueRange indices,
ArrayRef<bool> inBounds = ArrayRef<bool>{});
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

View File

@ -40,6 +40,9 @@ using llvm::dbgs;
#define DEBUG_TYPE "linalg-vectorization"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
@ -106,7 +109,7 @@ struct VectorizationResult {
/// ShapedType of `v`.
static VectorType extractVectorTypeFromShapedValue(Value v) {
auto st = v.getType().cast<ShapedType>();
if (st.isa<MemRefType>() && st.getShape().empty())
if (st.getShape().empty())
return VectorType();
return VectorType::get(st.getShape(), st.getElementType());
}
@ -163,16 +166,23 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
/// If value of assumed VectorType has a shape different than `shape`, build and
/// return a new vector.broadcast to `shape`.
/// Otherwise, just return value.
static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
Value value, OpOperand *outputOperand) {
/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
/// whether a reduction is needed to produce a `targetType` and create that
/// reduction if it is the case.
static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
OpOperand *outputOperand) {
LDBG("Reduce " << value << " to type " << targetType);
LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
<< *(outputOperand->getOwner()));
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
auto vecType = value.getType().dyn_cast<VectorType>();
if (!vecType || vecType.getShape() == targetVectorType.getShape())
VectorType targetVectorType = targetType.dyn_cast<VectorType>();
if (!vecType)
return value;
if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
return value;
// At this point, we know we need to reduce. Detect the reduction operator.
unsigned pos = 0;
MLIRContext *ctx = b.getContext();
SmallVector<AffineExpr> exprs;
@ -181,7 +191,6 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
exprs.push_back(getAffineDimExpr(pos++, ctx));
auto loc = value.getLoc();
// At this point, we know we need to reduce. Detect the reduction operator.
auto maybeKind = matchLinalgReduction(outputOperand);
assert(maybeKind && "Failed precondition: could not get reduction kind");
unsigned idx = 0;
@ -196,16 +205,18 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
}
/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If source has rank zero, build an memref.load.
/// If source has rank zero, build a `vector<1xt> transfer_read + extract`.
/// Return the produced value.
static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
AffineMap map) {
Location loc = source.getLoc();
auto shapedType = source.getType().cast<ShapedType>();
SmallVector<Value> indices(shapedType.getRank(),
b.create<ConstantIndexOp>(loc, 0));
return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
map);
if (auto vectorType = readType.dyn_cast<VectorType>())
return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
map);
return vector::TransferReadOp::createScalarOp(b, loc, source, indices);
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@ -216,13 +227,14 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
OpOperand *outputOperand) {
Operation *write;
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
if (VectorType vectorType =
extractVectorTypeFromShapedValue(outputOperand->get())) {
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
AffineMap map =
reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
SmallVector<int64_t> transposeShape =
applyPermutationMap(inversePermutation(map), vectorType.getShape());
assert(!transposeShape.empty() && "unexpected empty transpose shape");
vectorType = VectorType::get(transposeShape, vectorType.getElementType());
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
b.create<ConstantIndexOp>(loc, 0));
@ -231,9 +243,12 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
indices, map);
} else {
write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
value =
reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand);
write = vector::TransferWriteOp::createScalarOp(
b, loc, value, outputOperand->get(), ValueRange{});
}
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
LDBG("vectorized op: " << *write);
if (!write->getResults().empty())
return write->getResult(0);
return Value();
@ -329,7 +344,7 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
static VectorizationResult
vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
LDBG("vectorize op " << *op);
// 1. Try to apply any CustomVectorizationHook.
if (!customVectorizationHooks.empty()) {
@ -466,33 +481,27 @@ LogicalResult vectorizeAsLinalgGeneric(
continue;
}
// TODO: 0-d vectors.
if (linalgOp.getShape(opOperand).empty()) {
Value loaded =
b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << loaded);
bvm.map(bbarg, loaded);
bvm.map(opOperand->get(), loaded);
continue;
}
Type readType;
AffineMap map;
VectorType vectorType;
if (broadcastToMaximalCommonShape) {
map = inverseAndBroadcastProjectedPermuation(
linalgOp.getTiedIndexingMap(opOperand));
vectorType = VectorType::get(commonVectorShape,
getElementTypeOrSelf(opOperand->get()));
if (linalgOp.getShape(opOperand).empty()) {
readType = bbarg.getType();
} else {
map = inversePermutation(
reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
if (broadcastToMaximalCommonShape) {
map = inverseAndBroadcastProjectedPermuation(
linalgOp.getTiedIndexingMap(opOperand));
readType = VectorType::get(commonVectorShape,
getElementTypeOrSelf(opOperand->get()));
} else {
map = inversePermutation(
reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
getElementTypeOrSelf(opOperand->get()));
}
}
Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << vectorRead);
bvm.map(bbarg, vectorRead);
bvm.map(opOperand->get(), vectorRead);
Value readValue = buildVectorRead(b, opOperand->get(), readType, map);
LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
bvm.map(bbarg, readValue);
bvm.map(opOperand->get(), readValue);
}
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@ -516,12 +525,11 @@ LogicalResult vectorizeAsLinalgGeneric(
for (Operation &op : block.getOperations()) {
VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
LDBG("failed to vectorize: " << op);
return failure();
}
if (result.status == VectorizationStatus::NewOp) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
<< *result.newOp;);
LDBG("new vector op: " << *result.newOp;);
bvm.map(op.getResults(), result.newOp->getResults());
}
}
@ -536,9 +544,9 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
Location loc = linalgOp.getLoc();
// Vectorize other ops as vector contraction.
// TODO: interface.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg op as vector.contract: ";
linalgOp.dump());
LDBG(""
<< "Rewrite linalg op as vector.contract: ";
linalgOp.dump());
// Special function that describes how to vectorize the multiplication op in a
// linalg contraction.
CustomVectorizationHook vectorizeContraction =
@ -592,11 +600,15 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
if (llvm::none_of(op.iterator_types(), isReductionIterator))
if (llvm::none_of(op.iterator_types(), isReductionIterator)) {
LDBG("reduction precondition failed: no reduction iterator");
return failure();
}
for (OpOperand *opOperand : op.getOutputOperands()) {
if (!matchLinalgReduction(opOperand))
if (!matchLinalgReduction(opOperand)) {
LDBG("reduction precondition failed: reduction detection failed");
return failure();
}
}
return success();
}
@ -604,8 +616,10 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
if (linalgOp.hasDynamicShape())
if (linalgOp.hasDynamicShape()) {
LDBG("precondition failed: dynamic shape");
return failure();
}
if (isElementwise(op))
return success();
if (isaContractionOpInterface(linalgOp))
@ -613,10 +627,15 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
// TODO: the common vector shape is equal to the static loop sizes only when
// all indexing maps are projected permutations. For convs and stencils the
// logic will need to evolve.
if (allIndexingsAreProjectedPermutation(linalgOp) &&
succeeded(reductionPreconditions(linalgOp)))
return success();
return failure();
if (!allIndexingsAreProjectedPermutation(linalgOp)) {
LDBG("precondition failed: not projected permutations");
return failure();
}
if (failed(reductionPreconditions(linalgOp))) {
LDBG("precondition failed: reduction preconditions");
return failure();
}
return success();
}
LogicalResult
@ -629,10 +648,10 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
if (isaContractionOpInterface(linalgOp))
return vectorizeContraction(b, linalgOp, newResults);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Vectorize linalg op as a generic by broadcasting to "
"maximal common shape: "
<< *op);
LDBG(""
<< "Vectorize linalg op as a generic by broadcasting to "
"maximal common shape: "
<< *op);
return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
/*broadcastToMaximalCommonShape=*/true);
}
@ -1200,9 +1219,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
ValueRange values) {
if (firstOp->getBlock() != secondOp->getBlock() ||
!firstOp->isBeforeInBlock(secondOp)) {
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "interleavedUses precondition failed, firstOp: "
<< *firstOp << ", second op: " << *secondOp);
LDBG("interleavedUses precondition failed, firstOp: "
<< *firstOp << ", second op: " << *secondOp);
return true;
}
for (auto v : values) {
@ -1214,10 +1232,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
if (owner->getBlock() == firstOp->getBlock() &&
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
continue;
LLVM_DEBUG(llvm::dbgs()
<< "\n[" DEBUG_TYPE "]: "
<< " found interleaved op " << *owner
<< ", firstOp: " << *firstOp << ", second op: " << *secondOp);
LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
<< ", second op: " << *secondOp);
return true;
}
}
@ -1248,15 +1264,14 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return failure();
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
LDBG(viewOrAlloc);
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return failure();
Value subView = subViewOp.getResult();
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "with subView " << subView);
LDBG("with subView " << subView);
// Find the copy into `subView` without interleaved uses.
CopyOp copyOp;
@ -1265,8 +1280,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
assert(newCopyOp.output().getType().isa<MemRefType>());
if (newCopyOp.output() != subView)
continue;
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "copy candidate " << *newCopyOp);
LDBG("copy candidate " << *newCopyOp);
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
continue;
copyOp = newCopyOp;
@ -1275,8 +1289,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
}
if (!copyOp)
return failure();
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "with copy " << *copyOp);
LDBG("with copy " << *copyOp);
// Find the fill into `viewOrAlloc` without interleaved uses before the copy.
FillOp maybeFillOp;
@ -1285,8 +1298,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
assert(newFillOp.output().getType().isa<MemRefType>());
if (newFillOp.output() != viewOrAlloc)
continue;
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "fill candidate " << *newFillOp);
LDBG("fill candidate " << *newFillOp);
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
continue;
maybeFillOp = newFillOp;
@ -1297,8 +1309,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
return failure();
if (maybeFillOp)
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "with maybeFillOp " << *maybeFillOp);
LDBG("with maybeFillOp " << *maybeFillOp);
// `in` is the subview that linalg.copy reads. Replace it.
Value in = copyOp.input();

View File

@ -2439,6 +2439,18 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
/*mask=*/Value(), inBounds);
}
Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc,
Value source, ValueRange indices,
ArrayRef<bool> inBounds) {
Type elemType = source.getType().cast<ShapedType>().getElementType();
auto vectorType = VectorType::get(ArrayRef<int64_t>{1}, elemType);
AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, loc.getContext()));
Value read = builder.create<vector::TransferReadOp>(loc, vectorType, source,
indices, map, inBounds);
return builder.create<vector::ExtractOp>(loc, read, ArrayRef<int64_t>{0});
}
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 3> elidedAttrs;
elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
@ -2769,6 +2781,16 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
// TransferWriteOp
//===----------------------------------------------------------------------===//
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
AffineMap permutationMap, ArrayRef<bool> inBounds) {
if (inBounds.empty())
return build(builder, result, vector, dest, indices, permutationMap,
/*mask=*/Value(), ArrayAttr());
build(builder, result, vector, dest, indices, permutationMap,
/*mask=*/Value(), builder.getBoolArrayAttr(inBounds));
}
/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
@ -2783,13 +2805,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
AffineMap permutationMap) {
build(builder, result, vector, source, indices, permutationMap,
/*inBounds=*/ArrayAttr());
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value source, ValueRange indices,
AffineMapAttr permutationMap,
@ -2817,6 +2832,20 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
mask, inBounds);
}
Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc,
Value value, Value dest,
ValueRange indices,
ArrayRef<bool> inBounds) {
Value vectorOfAScalar = value;
if (!value.getType().isa<VectorType>())
vectorOfAScalar = builder.create<vector::BroadcastOp>(
loc, VectorType::get({1}, value.getType()), value);
AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, loc.getContext()));
return builder.create<vector::TransferWriteOp>(loc, vectorOfAScalar, dest,
indices, map, inBounds);
}
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();

View File

@ -203,8 +203,9 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK-LABEL: func @test_vectorize_fill
func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
// CHECK: store %[[V]], %[[M]][] : memref<f32>
// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
// CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
// CHECK: vector.transfer_write %[[VEC]], %[[M]][] {{.*}} : vector<1xf32>, memref<f32>
linalg.fill(%arg0, %A) : f32, memref<f32>
return
}
@ -223,8 +224,11 @@ func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
// CHECK-LABEL: func @test_vectorize_copy_scalar
func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
// CHECK: %[[V:.*]] = memref.load {{.*}} : memref<f32>
// CHECK: store %[[V]], {{.*}} : memref<f32>
// CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
// CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<1xf32>
// CHECK: %[[val:.*]] = vector.extract %[[V]][0] : vector<1xf32>
// CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
// CHECK: vector.transfer_write %[[VV]], %[[B]][] {{.*}} : vector<1xf32>, memref<f32>
linalg.copy(%A, %B) : memref<f32>, memref<f32>
return
}
@ -857,3 +861,42 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
return %red : tensor<4xf32>
}
// -----
// CHECK-LABEL: func @reduce_1d(
// CHECK-SAME: %[[A:.*]]: tensor<32xf32>
func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK-DAG: %[[F0_v1:.*]] = constant dense<0.000000e+00> : vector<1xf32>
// CHECK-DAG: %[[F0_v32:.*]] = constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
%f0 = constant 0.000000e+00 : f32
// CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
%0 = linalg.init_tensor [] : tensor<f32>
// CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][]
// CHECK-SAME: : vector<1xf32>, tensor<f32>
%1 = linalg.fill(%f0, %0) : f32, tensor<f32> -> tensor<f32>
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
// CHECK: %[[a:.*]] = addf %[[r]], %[[F0_v32]] : vector<32xf32>
// CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[a]] [0]
// CHECK-SAME: : vector<32xf32> to f32
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<1xf32>
// CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
// CHECK-SAME: : vector<1xf32>, tensor<f32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> ()>],
iterator_types = ["reduction"]}
ins(%arg0 : tensor<32xf32>)
outs(%1 : tensor<f32>) {
^bb0(%a: f32, %b: f32): // no predecessors
%3 = addf %a, %b : f32
linalg.yield %3 : f32
} -> tensor<f32>
return %2 : tensor<f32>
}