diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 9a29825eda50..212084a75d4c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -318,6 +318,7 @@ def Vector_ReductionOp :
 
 def Vector_MultiDimReductionOp :
   Vector_Op<"multi_reduction", [NoSideEffect,
+     AllTypesMatch<["dest", "acc"]>,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      DeclareOpInterfaceMethods<InferTypeOpInterface>,
@@ -325,6 +326,7 @@ def Vector_MultiDimReductionOp :
                                ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
                    AnyVector:$source,
+                   AnyType:$acc,
                    I64ArrayAttr:$reduction_dims)>,
     Results<(outs AnyType:$dest)> {
   let summary = "Multi-dimensional reduction operation";
@@ -332,19 +334,20 @@ def Vector_MultiDimReductionOp :
     Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
     using the given operation (add/mul/min/max for int/fp and and/or/xor for
     int only).
+    Takes an initial accumulator operand.
 
     Example:
 
     ```mlir
-    %1 = vector.multi_reduction <add>, %0 [1, 3] :
+    %1 = vector.multi_reduction <add>, %0, %acc0 [1, 3] :
       vector<4x8x16x32xf32> into vector<4x16xf32>
-    %2 = vector.multi_reduction <add>, %1 [0, 1] :
+    %2 = vector.multi_reduction <add>, %1, %acc1 [0, 1] :
       vector<4x16xf32> into f32
     ```
   }];
   let builders = [
-    OpBuilder<(ins "Value":$source, "ArrayRef<bool>":$reductionMask,
-      "CombiningKind":$kind)>
+    OpBuilder<(ins "Value":$source, "Value":$acc,
+      "ArrayRef<bool>":$reductionMask, "CombiningKind":$kind)>
   ];
   let extraClassDeclaration = [{
     static StringRef getKindAttrStrName() { return "kind"; }
@@ -378,8 +381,9 @@ def Vector_MultiDimReductionOp :
     }
   }];
   let assemblyFormat =
-    "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
+    "$kind `,` $source `,` $acc attr-dict $reduction_dims `:` type($source) `to` type($dest)";
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Vector_BroadcastOp :
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c406dc5de88c..3422ab7c0a76 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -174,13 +174,13 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
 /// assumes that `reductionOp` has two operands and one of them is the reduction initial value.
-static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
-                                 Value valueToReduce,
-                                 const SmallVector<bool> &reductionMask) {
+static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
+                                      Value valueToReduce, Value acc,
+                                      const SmallVector<bool> &reductionMask) {
   auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
-      reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
+      reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
 }
 
 static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
@@ -315,10 +315,7 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
       (outputType && reduceType.getShape() == outputType.getShape()))
     return nullptr;
   SmallVector<bool> reductionMask = getReductionMask(linalgOp);
-  Value reduce = buildMultiDimReduce(b, op, reduceVec, reductionMask);
-  return b.create(op->getLoc(), op->getName().getIdentifier(),
-                  /*operands=*/{reduce, outputVec}, reduce.getType(),
-                  op->getAttrs());
+  return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
 }
 
 /// Generic vectorization for a single operation `op`, given already vectorized
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 00db4650f120..f803868c2150 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -334,34 +334,14 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
 
 void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                         OperationState &result, Value source,
-                                        ArrayRef<bool> reductionMask,
+                                        Value acc, ArrayRef<bool> reductionMask,
                                         CombiningKind kind) {
   SmallVector<int64_t> reductionDims;
   for (const auto &en : llvm::enumerate(reductionMask))
     if (en.value())
       reductionDims.push_back(en.index());
-  build(builder, result, kind, source, builder.getI64ArrayAttr(reductionDims));
-}
-
-LogicalResult MultiDimReductionOp::inferReturnTypes(
-    MLIRContext *, Optional<Location>, ValueRange operands,
-    DictionaryAttr attributes, RegionRange,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  MultiDimReductionOp::Adaptor op(operands, attributes);
-  auto vectorType = op.getSource().getType().cast<VectorType>();
-  SmallVector<int64_t> targetShape;
-  for (auto it : llvm::enumerate(vectorType.getShape()))
-    if (!llvm::any_of(op.getReductionDims().getValue(), [&](Attribute attr) {
-          return attr.cast<IntegerAttr>().getValue() == it.index();
-        }))
-      targetShape.push_back(it.value());
-  // TODO: update to also allow 0-d vectors when available.
-  if (targetShape.empty())
-    inferredReturnTypes.push_back(vectorType.getElementType());
-  else
-    inferredReturnTypes.push_back(
-        VectorType::get(targetShape, vectorType.getElementType()));
-  return success();
+  build(builder, result, kind, source, acc,
+        builder.getI64ArrayAttr(reductionDims));
 }
 
 OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
@@ -375,6 +355,28 @@ Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getSourceVectorType().getShape());
 }
 
+LogicalResult MultiDimReductionOp::verify() {
+  SmallVector<int64_t> targetShape;
+  Type inferredReturnType;
+  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
+    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
+          return attr.cast<IntegerAttr>().getValue() == it.index();
+        }))
+      targetShape.push_back(it.value());
+  // TODO: update to also allow 0-d vectors when available.
+  if (targetShape.empty())
+    inferredReturnType = getSourceVectorType().getElementType();
+  else
+    inferredReturnType =
+        VectorType::get(targetShape, getSourceVectorType().getElementType());
+  if (getType() != inferredReturnType)
+    return emitOpError() << "destination type " << getType()
+                         << " is incompatible with source type "
+                         << getSourceVectorType();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 0e023ca44832..2582781aaab0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -87,8 +87,8 @@ public:
       reductionMask[i] = true;
     }
     rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
-        multiReductionOp, transposeOp.getResult(), reductionMask,
-        multiReductionOp.getKind());
+        multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
+        reductionMask, multiReductionOp.getKind());
     return success();
   }
 
@@ -188,11 +188,17 @@ public:
         vectorShape, multiReductionOp.getSourceVectorType().getElementType());
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
-
+    Value acc = multiReductionOp.getAcc();
+    if (flattenedParallelDim) {
+      auto accType = VectorType::get(
+          {flattenedParallelDim},
+          multiReductionOp.getSourceVectorType().getElementType());
+      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
+    }
     // 5. Creates the flattened form of vector.multi_reduction with inner/outer
     // most dim as reduction.
     auto newOp = rewriter.create<vector::MultiDimReductionOp>(
-        loc, cast, mask, multiReductionOp.getKind());
+        loc, cast, acc, mask, multiReductionOp.getKind());
 
     // 6. If there are no parallel shapes, the result is a scalar.
     // TODO: support 0-d vectors when available.
@@ -238,10 +244,8 @@ struct TwoDimMultiReductionToElementWise
     if (!elementType.isIntOrIndexOrFloat())
       return failure();
 
-    Value result =
-        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.getSource(), 0)
-            .getResult();
-    for (int64_t i = 1; i < srcShape[0]; i++) {
+    Value result = multiReductionOp.getAcc();
+    for (int64_t i = 0; i < srcShape[0]; i++) {
       auto operand = rewriter.create<vector::ExtractOp>(
           loc, multiReductionOp.getSource(), i);
       result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
@@ -277,8 +281,10 @@ struct TwoDimMultiReductionToReduction
     for (int i = 0; i < outerDim; ++i) {
       auto v = rewriter.create<vector::ExtractOp>(
           loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
+      auto acc = rewriter.create<vector::ExtractOp>(
+          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
       auto reducedValue = rewriter.create<vector::ReductionOp>(
-          loc, multiReductionOp.getKind(), v);
+          loc, multiReductionOp.getKind(), v, acc);
       result = rewriter.create<vector::InsertElementOp>(
           loc, reducedValue, result,
           rewriter.create<arith::ConstantIndexOp>(loc, i));
@@ -309,6 +315,8 @@ struct OneDimMultiReductionToTwoDim
     auto srcShape = srcVectorType.getShape();
     auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
                                       srcVectorType.getElementType());
+    auto accType =
+        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
     assert(!multiReductionOp.getDestType().isa<VectorType>() &&
            "multi_reduction with a single dimension expects a scalar result");
 
@@ -319,8 +327,10 @@ struct OneDimMultiReductionToTwoDim
     /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
+    Value castAcc = rewriter.create<vector::BroadcastOp>(
+        loc, accType, multiReductionOp.getAcc());
     Value reduced = rewriter.create<vector::MultiDimReductionOp>(
-        loc, cast, mask, multiReductionOp.getKind());
+        loc, cast, castAcc, mask, multiReductionOp.getKind());
     rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
                                                    ArrayRef<int64_t>{0});
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bb8cc2bfae39..76151fc358ad 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -997,11 +997,8 @@ struct MultiReduceToContract
     }
     auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                  /*symCount=*/0, exprs, reduceOp.getContext());
-    Value zero = rewriter.create<arith::ConstantOp>(
-        reduceOp.getLoc(), reduceOp.getDestType(),
-        rewriter.getZeroAttr(reduceOp.getDestType()));
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
-        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
+        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
         rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
         rewriter.getStrArrayAttr(iteratorTypes));
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index d75d1098d53f..15f43dc0536c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -431,10 +431,11 @@ struct UnrollMultiReductionPattern
     SmallVector<int64_t> offsets =
         getVectorOffset(originalSize, *targetShape, i);
 
+    SmallVector<Value> operands;
     SmallVector<int64_t> operandStrides(offsets.size(), 1);
     Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);
-
+        loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+    operands.push_back(slicedOperand);
     SmallVector<int64_t> dstShape;
     SmallVector<int64_t> destOffset;
     for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
@@ -443,17 +444,22 @@ struct UnrollMultiReductionPattern
         dstShape.push_back((*targetShape)[i]);
       }
     }
+    Value acc;
+    SmallVector<int64_t> accStrides(destOffset.size(), 1);
+    // If a version of the accumulator has already been computed, use it;
+    // otherwise extract the first version from the original accumulator.
+    auto accIt = accCache.find(destOffset);
+    if (accIt != accCache.end())
+      acc = accIt->second;
+    else
+      acc = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
+    operands.push_back(acc);
     auto targetType = VectorType::get(
         dstShape, reductionOp.getSourceVectorType().getElementType());
     Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
-                                                   slicedOperand, targetType);
+                                                   operands, targetType);
     Value result = newOp->getResult(0);
-    // Save the accumulated value until all the loops are unrolled since
-    // reduction loop keeps updating the accumulator.
-    auto accIt = accCache.find(destOffset);
-    if (accIt != accCache.end())
-      result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
-                                  result, accIt->second);
     accCache[destOffset] = result;
   }
   // Assemble back the accumulator into a single vector.
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index b4dc8e432567..b08d7d1ff1c1 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -10,9 +10,8 @@ func.func @vectorize_matmul(%arg0: tensor<24x12xf32>,
   // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
   // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
   // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
-  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
-  // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
-  // CHECK: vector.transfer_write %[[vS]], %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
@@ -67,9 +66,8 @@ func.func @vectorize_keep_pad(
   // CHECK: %[[vA:.+]] = vector.transfer_read %[[pA]]
   // CHECK: %[[vB:.+]] = vector.transfer_read %[[pB]]
   // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
-  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
-  // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
-  // CHECK: vector.transfer_write %[[vS]], %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
   %8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
   %9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
   return %9 : tensor<24x25xf32>
@@ -127,9 +125,8 @@ func.func @vectorize_pad(
     tensor.yield %cst : f32
   } : tensor<?x?xf32> to tensor<7x5xf32>
   // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
-  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
-  // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
-  // CHECK: vector.transfer_write %[[vS]], %[[C]]
+  // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+  // CHECK: vector.transfer_write %[[vR]], %[[C]]
   %8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
   %9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
   return %9 : tensor<24x25xf32>
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index dbd09576cb76..bbc36b12556e 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -6,8 +6,7 @@ func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32> -// CHECK: vector.multi_reduction , %{{.*}} [0] : vector<1584xf32> to f32 -// CHECK: arith.addf %{{.*}}, %{{.*}} : f32 +// CHECK: vector.multi_reduction , %{{.*}}, {{.*}} [0] : vector<1584xf32> to f32 linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>) outs(%C: memref) return @@ -19,8 +18,7 @@ func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memre func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32> -// CHECK: vector.multi_reduction , %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32> +// CHECK: vector.multi_reduction , %{{.*}}, {{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32> linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>) outs(%C: memref<1584xf32>) return @@ -31,8 +29,7 @@ func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, % // CHECK-LABEL: contraction_matmul func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32> -// CHECK: vector.multi_reduction , %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32> +// CHECK: vector.multi_reduction , %{{.*}}, {{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32> linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>) outs(%C: memref<1584x1584xf32>) return @@ -43,8 +40,7 @@ func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf3 // CHECK-LABEL: contraction_batch_matmul func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32> -// CHECK: vector.multi_reduction , %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32> +// CHECK: vector.multi_reduction , %{{.*}}, {{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32> linalg.batch_matmul ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>) outs(%C: memref<1584x1584x1584xf32>) @@ -69,10 +65,9 @@ func.func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> - // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> + // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> - // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to 
vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> linalg.generic #matmul_trait ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>) @@ -103,10 +98,9 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<32x8xf32>) { // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> - // CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32> + // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32> // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32> - // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32> linalg.generic #matmul_transpose_out_trait ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>) @@ -157,11 +151,9 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32 %C: memref<8x32xi32>) { // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32> // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32> - // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> + // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> // CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32> - // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32> - // CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32> - + // CHECK: vector.multi_reduction , %[[MUL]], %[[ACC]] [2] : vector<8x32x16xi32> to vector<8x32xi32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> linalg.generic #matmul_trait ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>) @@ -180,8 +172,7 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32 func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: vector.multi_reduction , %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32> - // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32> + // CHECK: vector.multi_reduction , %{{.*}}, {{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32> linalg.matmul ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>) outs(%C: memref<8x32xf32>) @@ -560,9 +551,8 @@ func.func @matmul_tensors( // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later // convert it to a 2D contract. 
// CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32> - // CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32> - // CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]], %[[V2]] [2] : vector<8x12x4xf32> to vector<8x12xf32> + // CHECK: %[[W:.*]] = vector.transfer_write %[[R]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32> %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>) outs(%arg2: tensor<8x12xf32>) -> tensor<8x12xf32> @@ -801,8 +791,7 @@ func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>) // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32> // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32> // CHECK: math.exp {{.*}} : vector<4x16x8xf32> - // CHECK: vector.multi_reduction , %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32> - // CHECK: addf {{.*}} : vector<4x16xf32> + // CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32> // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32> // CHECK: return {{.*}} : tensor<4x16xf32> %0 = linalg.generic { @@ -836,8 +825,7 @@ func.func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: addf {{.*}} : vector<2x3x4x5xf32> - // CHECK: vector.multi_reduction , {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> - // CHECK: addf {{.*}} : vector<2x5xf32> + // CHECK: vector.multi_reduction , {{.*}}, %{{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32> // CHECK: return {{.*}} : tensor<5x2xf32> %0 = linalg.generic { @@ -865,8 +853,7 @@ func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32> // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> - // CHECK: %[[R:.+]] = vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> - // CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32> + // CHECK: vector.multi_reduction , {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant -3.40282e+38 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> @@ -890,8 +877,7 @@ func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: %[[R:.+]] = vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> - // CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32> + // CHECK: vector.multi_reduction , {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %maxf32 = arith.constant 3.40282e+38 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> @@ -914,7 +900,7 @@ func.func 
@red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: linalg.init_tensor [4] : tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xf32> to vector<4xf32> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant 1.0 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> @@ -937,7 +923,7 @@ func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { // CHECK: linalg.init_tensor [4] : tensor<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> - // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xi1> to vector<4xi1> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false %init = linalg.init_tensor [4] : tensor<4xi1> @@ -960,7 +946,7 @@ func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { // CHECK: linalg.init_tensor [4] : tensor<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> - // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xi1> to vector<4xi1> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant true %init = linalg.init_tensor [4] : tensor<4xi1> @@ -983,7 +969,7 @@ func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { // CHECK: linalg.init_tensor [4] : tensor<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> - // CHECK: vector.multi_reduction , {{.*}} [1] : vector<4x4xi1> to vector<4xi1> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> %ident = arith.constant false %init = linalg.init_tensor [4] : tensor<4xi1> @@ -1035,8 +1021,7 @@ func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32> // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32> // CHECK: subf {{.*}} : vector<4x4xf32> // CHECK: math.exp {{.*}} : vector<4x4xf32> - // CHECK: vector.multi_reduction , {{.*}} : vector<4x4xf32> to vector<4xf32> - // CHECK: addf {{.*}} : vector<4xf32> + // CHECK: vector.multi_reduction , {{.*}}, {{.*}} : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32> %c0 = arith.constant 0.0 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> @@ -1075,10 +1060,9 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor { // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> // CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector - // CHECK: %[[red:.*]] = vector.multi_reduction , %[[r]] [0] + // CHECK: %[[red:.*]] = vector.multi_reduction , %[[r]], %[[f0]] [0] // CHECK-SAME: : vector<32xf32> to f32 - // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32 - // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 
to vector + // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][] // CHECK-SAME: : vector, tensor %2 = linalg.generic { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 84b5a45f19e6..702670095c8d 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1281,9 +1281,9 @@ func.func @do_not_swap_extract_slice_transfer_write(%arg0 : vector<8xf32>, // ----- // CHECK-LABEL: func @vector_multi_reduction_single_parallel( -// CHECK-SAME: %[[v:.*]]: vector<2xf32> -func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction , %arg0 [] : vector<2xf32> to vector<2xf32> +// CHECK-SAME: %[[v:.*]]: vector<2xf32>, +func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: vector<2xf32>) -> vector<2xf32> { + %0 = vector.multi_reduction , %arg0, %acc [] : vector<2xf32> to vector<2xf32> // CHECK: return %[[v]] : vector<2xf32> return %0 : vector<2xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 87e5f9443807..d50315970d74 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1138,9 +1138,9 @@ func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 { // ----- -func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>) -> f32 { - // expected-error@+1 {{'vector.multi_reduction' op inferred type(s) 'vector<4xf32>' are incompatible with return type(s) of operation 'vector<16xf32>'}} - %0 = vector.multi_reduction , %arg0 [1] : vector<4x16xf32> to vector<16xf32> +func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>, %acc: vector<16xf32>) -> f32 { + // expected-error@+1 {{'vector.multi_reduction' op destination type 'vector<16xf32>' is incompatible with source type 'vector<4x16xf32>'}} + %0 = vector.multi_reduction , %arg0, %acc [1] : vector<4x16xf32> to vector<16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index e42c94f252c4..dc69bb0a78a6 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -705,10 +705,13 @@ func.func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>, } // CHECK-LABEL: @multi_reduction -func.func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 { - %1 = vector.multi_reduction , %0 [1, 3] : +func.func @multi_reduction(%0: vector<4x8x16x32xf32>, %acc0: vector<4x16xf32>, + %acc1: f32) -> f32 { + // CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3] : vector<4x8x16x32xf32> to vector<4x16xf32> + %1 = vector.multi_reduction , %0, %acc0 [1, 3] : vector<4x8x16x32xf32> to vector<4x16xf32> - %2 = vector.multi_reduction , %1 [0, 1] : + // CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1] : vector<4x16xf32> to f32 + %2 = vector.multi_reduction , %1, %acc1 [0, 1] : vector<4x16xf32> to f32 return %2 : f32 } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir index a39ac990f828..6b372c3ef1c3 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -1,40 +1,42 @@ // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s -func.func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = 
vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> +func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { + %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } // CHECK-LABEL: func @vector_multi_reduction -// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>) // CHECK: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32> // CHECK: %[[C0:.+]] = arith.constant 0 : index // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0] -// CHECK: %[[RV0:.+]] = vector.reduction , %[[V0]] : vector<4xf32> into f32 +// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0] +// CHECK: %[[RV0:.+]] = vector.reduction , %[[V0]], %[[ACC0]] : vector<4xf32> into f32 // CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32> // CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1] -// CHECK: %[[RV1:.+]] = vector.reduction , %[[V1]] : vector<4xf32> into f32 +// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][1] +// CHECK: %[[RV1:.+]] = vector.reduction , %[[V1]], %[[ACC1]] : vector<4xf32> into f32 // CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] -func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 { - %0 = vector.multi_reduction , %arg0 [0, 1] : vector<2x4xf32> to f32 +func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 { + %0 = vector.multi_reduction , %arg0, %acc [0, 1] : vector<2x4xf32> to f32 return %0 : f32 } // CHECK-LABEL: func @vector_multi_reduction_to_scalar -// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32) // CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32> -// CHECK: %[[REDUCED:.*]] = vector.reduction , %[[CASTED]] : vector<8xf32> into f32 +// CHECK: %[[REDUCED:.*]] = vector.reduction , %[[CASTED]], %[[ACC]] : vector<8xf32> into f32 // CHECK: %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32> // CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32> // CHECK: return %[[RES]] -func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> { - %0 = vector.multi_reduction , %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> +func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> { + %0 = vector.multi_reduction , %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> return %0 : vector<2x3xi32> } // CHECK-LABEL: func @vector_reduction_inner -// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32> // CHECK: %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32> // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index @@ -44,29 +46,35 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index // CHECK: %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32> // CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32> -// CHECK: %[[V0R:.+]] = vector.reduction , %[[V0]] : vector<20xi32> into 
i32 +// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : vector<2x3xi32> +// CHECK: %[[V0R:.+]] = vector.reduction , %[[V0]], %[[ACC0]] : vector<20xi32> into i32 // CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : index] : vector<6xi32> // CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32> -// CHECK: %[[V1R:.+]] = vector.reduction , %[[V1]] : vector<20xi32> into i32 +// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : vector<2x3xi32> +// CHECK: %[[V1R:.+]] = vector.reduction , %[[V1]], %[[ACC1]] : vector<20xi32> into i32 // CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : index] : vector<6xi32> // CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32> -// CHECK: %[[V2R:.+]] = vector.reduction , %[[V2]] : vector<20xi32> into i32 +// CHECK: %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : vector<2x3xi32> +// CHECK: %[[V2R:.+]] = vector.reduction , %[[V2]], %[[ACC2]] : vector<20xi32> into i32 // CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : index] : vector<6xi32> // CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32> -// CHECK: %[[V3R:.+]] = vector.reduction , %[[V3]] : vector<20xi32> into i32 +// CHECK: %[[ACC3:.+]] = vector.extract %[[ACC]][1, 0] : vector<2x3xi32> +// CHECK: %[[V3R:.+]] = vector.reduction , %[[V3]], %[[ACC3]] : vector<20xi32> into i32 // CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : index] : vector<6xi32> // CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32> -// CHECK: %[[V4R:.+]] = vector.reduction , %[[V4]] : vector<20xi32> into i32 +// CHECK: %[[ACC4:.+]] = vector.extract %[[ACC]][1, 1] : vector<2x3xi32> +// CHECK: %[[V4R:.+]] = vector.reduction , %[[V4]], %[[ACC4]] : vector<20xi32> into i32 // CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : index] : vector<6xi32> /// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32> -// CHECK: %[[V5R:.+]] = vector.reduction , %[[V5]] : vector<20xi32> into i32 +// CHECK: %[[ACC5:.+]] = vector.extract %[[ACC]][1, 2] : vector<2x3xi32> +// CHECK: %[[V5R:.+]] = vector.reduction , %[[V5]], %[[ACC5]] : vector<20xi32> into i32 // CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : index] : vector<6xi32> // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32> // CHECK: return %[[RESULT]] -func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> { - %0 = vector.multi_reduction , %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> +func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> { + %0 = vector.multi_reduction , %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> return %0 : vector<2x5xf32> } @@ -77,12 +85,12 @@ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vect // CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32> // CHECK: return %[[RESULT]] -func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> { - %0 = vector.multi_reduction , %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32> +func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: 
vector<2x4xf32>) -> vector<2x4xf32> { + %0 = vector.multi_reduction , %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32> return %0 : vector<2x4xf32> } // CHECK-LABEL: func @vector_multi_reduction_ordering -// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32> +// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32>, %[[ACC:.*]]: vector<2x4xf32>) // CHECK: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<8xf32> // CHECK: %[[C0:.+]] = arith.constant 0 : index // CHECK: %[[C1:.+]] = arith.constant 1 : index @@ -94,28 +102,36 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2 // CHECK: %[[C7:.+]] = arith.constant 7 : index // CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0] -// CHECK: %[[RV0:.+]] = vector.reduction , %[[V0]] : vector<3xf32> into f32 +// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : vector<2x4xf32> +// CHECK: %[[RV0:.+]] = vector.reduction , %[[V0]], %[[ACC0]] : vector<3xf32> into f32 // CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<8xf32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1] -// CHECK: %[[RV1:.+]] = vector.reduction , %[[V1]] : vector<3xf32> into f32 +// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : vector<2x4xf32> +// CHECK: %[[RV1:.+]] = vector.reduction , %[[V1]], %[[ACC1]] : vector<3xf32> into f32 // CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<8xf32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2] -// CHECK: %[[RV2:.+]] = vector.reduction , %[[V2]] : vector<3xf32> into f32 +// CHECK: %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : vector<2x4xf32> +// CHECK: %[[RV2:.+]] = vector.reduction , %[[V2]], %[[ACC2]] : vector<3xf32> into f32 // CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : index] : vector<8xf32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3] -// CHECK: %[[RV3:.+]] = vector.reduction , %[[V3]] : vector<3xf32> into f32 +// CHECK: %[[ACC3:.+]] = vector.extract %[[ACC]][0, 3] : vector<2x4xf32> +// CHECK: %[[RV3:.+]] = vector.reduction , %[[V3]], %[[ACC3]] : vector<3xf32> into f32 // CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : index] : vector<8xf32> // CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0] -// CHECK: %[[RV4:.+]] = vector.reduction , %[[V4]] : vector<3xf32> into f32 +// CHECK: %[[ACC4:.+]] = vector.extract %[[ACC]][1, 0] : vector<2x4xf32> +// CHECK: %[[RV4:.+]] = vector.reduction , %[[V4]], %[[ACC4]] : vector<3xf32> into f32 // CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : index] : vector<8xf32> // CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1] -// CHECK: %[[RV5:.+]] = vector.reduction , %[[V5]] : vector<3xf32> into f32 +// CHECK: %[[ACC5:.+]] = vector.extract %[[ACC]][1, 1] : vector<2x4xf32> +// CHECK: %[[RV5:.+]] = vector.reduction , %[[V5]], %[[ACC5]] : vector<3xf32> into f32 // CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : index] : vector<8xf32> // CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2] -// CHECK: %[[RV6:.+]] = vector.reduction , %[[V6]] : vector<3xf32> into f32 +// CHECK: %[[ACC6:.+]] = vector.extract %[[ACC]][1, 2] : vector<2x4xf32> +// CHECK: 
%[[RV6:.+]] = vector.reduction , %[[V6]], %[[ACC6]] : vector<3xf32> into f32 // CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : index] : vector<8xf32> // CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3] -// CHECK: %[[RV7:.+]] = vector.reduction , %[[V7]] : vector<3xf32> into f32 +// CHECK: %[[ACC7:.+]] = vector.extract %[[ACC]][1, 3] : vector<2x4xf32> +// CHECK: %[[RV7:.+]] = vector.reduction , %[[V7]], %[[ACC7]] : vector<3xf32> into f32 // CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32> // CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32> // CHECK: return %[[RESHAPED_VEC]] diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir index bd06e48f4823..8a8bf86bfd38 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -1,101 +1,107 @@ // RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s -func.func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> +func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { + %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } // CHECK-LABEL: func @vector_multi_reduction -// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32> // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> +// CHECK: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> -// CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[V0]] : vector<2xf32> +// CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> // CHECK: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> // CHECK: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] : vector<2xf32> -func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> +func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { + %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } // CHECK-LABEL: func @vector_multi_reduction_min -// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32> // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> +// CHECK: %[[RV0:.+]] = arith.minf %[[V0]], %[[ACC]] : vector<2xf32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> -// CHECK: %[[RV01:.+]] = arith.minf %[[V1]], %[[V0]] : vector<2xf32> +// CHECK: %[[RV01:.+]] = arith.minf %[[V1]], 
%[[RV0]] : vector<2xf32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> // CHECK: %[[RV012:.+]] = arith.minf %[[V2]], %[[RV01]] : vector<2xf32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> // CHECK: %[[RESULT_VEC:.+]] = arith.minf %[[V3]], %[[RV012]] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] : vector<2xf32> -func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> { - %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xf32> to vector<2xf32> +func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { + %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> return %0 : vector<2xf32> } // CHECK-LABEL: func @vector_multi_reduction_max -// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32> // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> +// CHECK: %[[RV0:.+]] = arith.maxf %[[V0]], %[[ACC]] : vector<2xf32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> -// CHECK: %[[RV01:.+]] = arith.maxf %[[V1]], %[[V0]] : vector<2xf32> +// CHECK: %[[RV01:.+]] = arith.maxf %[[V1]], %[[RV0]] : vector<2xf32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> // CHECK: %[[RV012:.+]] = arith.maxf %[[V2]], %[[RV01]] : vector<2xf32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> // CHECK: %[[RESULT_VEC:.+]] = arith.maxf %[[V3]], %[[RV012]] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] : vector<2xf32> -func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> { - %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xi32> to vector<2xi32> +func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> { + %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32> return %0 : vector<2xi32> } // CHECK-LABEL: func @vector_multi_reduction_and -// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32> // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32> +// CHECK: %[[RV0:.+]] = arith.andi %[[V0]], %[[ACC]] : vector<2xi32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32> -// CHECK: %[[RV01:.+]] = arith.andi %[[V1]], %[[V0]] : vector<2xi32> +// CHECK: %[[RV01:.+]] = arith.andi %[[V1]], %[[RV0]] : vector<2xi32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32> // CHECK: %[[RV012:.+]] = arith.andi %[[V2]], %[[RV01]] : vector<2xi32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32> // CHECK: %[[RESULT_VEC:.+]] = arith.andi %[[V3]], %[[RV012]] : vector<2xi32> // CHECK: return %[[RESULT_VEC]] : vector<2xi32> -func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> { - %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xi32> to vector<2xi32> +func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> { + %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32> return %0 : vector<2xi32> } // CHECK-LABEL: func @vector_multi_reduction_or -// CHECK-SAME: %[[INPUT:.+]]: 
vector<2x4xi32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32> // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32> +// CHECK: %[[RV0:.+]] = arith.ori %[[V0]], %[[ACC]] : vector<2xi32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32> -// CHECK: %[[RV01:.+]] = arith.ori %[[V1]], %[[V0]] : vector<2xi32> +// CHECK: %[[RV01:.+]] = arith.ori %[[V1]], %[[RV0]] : vector<2xi32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32> // CHECK: %[[RV012:.+]] = arith.ori %[[V2]], %[[RV01]] : vector<2xi32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32> // CHECK: %[[RESULT_VEC:.+]] = arith.ori %[[V3]], %[[RV012]] : vector<2xi32> // CHECK: return %[[RESULT_VEC]] : vector<2xi32> -func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> { - %0 = vector.multi_reduction , %arg0 [1] : vector<2x4xi32> to vector<2xi32> +func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> { + %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32> return %0 : vector<2xi32> } // CHECK-LABEL: func @vector_multi_reduction_xor -// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32> // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32> +// CHECK: %[[RV0:.+]] = arith.xori %[[V0]], %[[ACC]] : vector<2xi32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32> -// CHECK: %[[RV01:.+]] = arith.xori %[[V1]], %[[V0]] : vector<2xi32> +// CHECK: %[[RV01:.+]] = arith.xori %[[V1]], %[[RV0]] : vector<2xi32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32> // CHECK: %[[RV012:.+]] = arith.xori %[[V2]], %[[RV01]] : vector<2xi32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32> @@ -103,18 +109,20 @@ func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> { // CHECK: return %[[RESULT_VEC]] : vector<2xi32> -func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> { - %0 = vector.multi_reduction , %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> +func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> { + %0 = vector.multi_reduction , %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> return %0 : vector<2x3xi32> } // CHECK-LABEL: func @vector_reduction_outer -// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32> +// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32> // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32> // CHECK: %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32> +// CHECK: %[[FACC:.+]] = vector.shape_cast %[[ACC]] : vector<2x3xi32> to vector<6xi32> // CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<20x6xi32> +// CHECK: %[[R:.+]] = arith.addi %[[V0]], %[[FACC]] : vector<6xi32> // CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<20x6xi32> -// CHECK: %[[R0:.+]] = arith.addi %[[V1]], %[[V0]] : vector<6xi32> +// CHECK: %[[R0:.+]] = arith.addi %[[V1]], %[[R]] : vector<6xi32> // CHECK: %[[V2:.+]] = 
vector.extract %[[RESHAPED]][2] : vector<20x6xi32> // CHECK: %[[R1:.+]] = arith.addi %[[V2]], %[[R0]] : vector<6xi32> // CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<20x6xi32> @@ -157,15 +165,15 @@ func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> // This test is mainly to catch a bug that running // `InnerOuterDimReductionConversion` on this function results in an // infinite loop. So just check that some value is returned. -func.func @vector_reduction_1D(%arg0 : vector<2xf32>) -> f32 { - %0 = vector.multi_reduction #vector.kind, %arg0 [0] : vector<2xf32> to f32 +func.func @vector_reduction_1D(%arg0 : vector<2xf32>, %acc: f32) -> f32 { + %0 = vector.multi_reduction #vector.kind, %arg0, %acc [0] : vector<2xf32> to f32 return %0 : f32 } // CHECK-LABEL: func @vector_reduction_1D // CHECK: return %{{.+}} -func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>) -> f32 { - %0 = vector.multi_reduction , %arg0 [0, 1] : vector<2x3xf32> to f32 +func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 { + %0 = vector.multi_reduction , %arg0, %acc [0, 1] : vector<2x3xf32> to f32 return %0 : f32 } // CHECK-LABEL: func @vector_multi_reduction_to_scalar diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir index ade539e27822..f1587c2e2f3d 100644 --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -4,15 +4,15 @@ // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-LABEL: multidimreduction_contract -// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32> +// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>) // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]], // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} -// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32> +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32> // CHECK-NEXT: return %[[R]] : vector<8x16xf32> func.func @multidimreduction_contract( - %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> { + %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>, %acc: vector<8x16xf32>) -> vector<8x16xf32> { %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32> - %1 = vector.multi_reduction , %0 [1] : vector<8x32x16xf32> to vector<8x16xf32> + %1 = vector.multi_reduction , %0, %acc [1] : vector<8x32x16xf32> to vector<8x16xf32> return %1 : vector<8x16xf32> } @@ -22,15 +22,15 @@ func.func @multidimreduction_contract( // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-LABEL: multidimreduction_contract_int -// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32> +// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>) // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]], // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} -// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32> +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into 
vector<8x16xi32> // CHECK-NEXT: return %[[R]] : vector<8x16xi32> func.func @multidimreduction_contract_int( - %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> { + %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>, %acc: vector<8x16xi32>) -> vector<8x16xi32> { %0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32> - %1 = vector.multi_reduction , %0 [1] : vector<8x32x16xi32> to vector<8x16xi32> + %1 = vector.multi_reduction , %0, %acc [1] : vector<8x32x16xi32> to vector<8x16xi32> return %1 : vector<8x16xi32> } diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index d0d10887d6a2..db6a40d489d6 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -188,30 +188,28 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf // CHECK-LABEL: func @vector_fma // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32> -func.func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> { - %0 = vector.multi_reduction #vector.kind, %v [1] : vector<4x6xf32> to vector<4xf32> +func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> { + %0 = vector.multi_reduction #vector.kind, %v, %acc [1] : vector<4x6xf32> to vector<4xf32> return %0 : vector<4xf32> } // CHECK-LABEL: func @vector_multi_reduction // CHECK: %[[V0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> // CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK: %[[R0:.*]] = vector.multi_reduction , %[[E0]] [1] : vector<2x2xf32> to vector<2xf32> +// CHECK: %[[ACC0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> +// CHECK: %[[R0:.*]] = vector.multi_reduction , %[[E0]], %[[ACC0]] [1] : vector<2x2xf32> to vector<2xf32> // CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK: %[[R1:.*]] = vector.multi_reduction , %[[E1]] [1] : vector<2x2xf32> to vector<2xf32> -// CHECK: %[[A0:.*]] = arith.addf %[[R1]], %[[R0]] : vector<2xf32> +// CHECK: %[[R1:.*]] = vector.multi_reduction , %[[E1]], %[[R0]] [1] : vector<2x2xf32> to vector<2xf32> // CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK: %[[R2:.*]] = vector.multi_reduction , %5 [1] : vector<2x2xf32> to vector<2xf32> -// CHECK: %[[A1:.*]] = arith.addf %[[R2]], %[[A0]] : vector<2xf32> +// CHECK: %[[R2:.*]] = vector.multi_reduction , %[[E2]], %[[R1]] [1] : vector<2x2xf32> to vector<2xf32> // CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK: %[[R3:.*]] = vector.multi_reduction , %[[E3]] [1] : vector<2x2xf32> to vector<2xf32> +// CHECK: %[[ACC1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> +// CHECK: %[[R3:.*]] = vector.multi_reduction , %[[E3]], %[[ACC1]] [1] : vector<2x2xf32> to vector<2xf32> // CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK: %[[R4:.*]] = vector.multi_reduction , %[[E4]] [1] : 
vector<2x2xf32> to vector<2xf32> -// CHECK: %[[A2:.*]] = arith.addf %[[R4]], %[[R3]] : vector<2xf32> +// CHECK: %[[R4:.*]] = vector.multi_reduction , %[[E4]], %[[R3]] [1] : vector<2x2xf32> to vector<2xf32> // CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK: %[[R5:.*]] = vector.multi_reduction , %[[E5]] [1] : vector<2x2xf32> to vector<2xf32> -// CHECK: %[[A3:.*]] = arith.addf %[[R5]], %[[A2]] : vector<2xf32> -// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> -// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> +// CHECK: %[[R5:.*]] = vector.multi_reduction , %[[E5]], %[[R4]] [1] : vector<2x2xf32> to vector<2xf32> +// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[R2]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> +// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> // CHECK: return %[[V2]] : vector<4xf32>
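
Not part of the patch: a minimal standalone sketch of the accumulator-carrying form introduced above, for readers skimming the diff. The function and value names are illustrative; the op syntax and the lowered shape follow the updated tests (ops.mlir and vector-multi-reduction-lowering.mlir) in this change.

```mlir
// vector.multi_reduction now takes the initial accumulator as an operand;
// the reduced source is combined with %acc using the same combining kind.
func.func @multi_reduction_with_acc(%src: vector<2x4xf32>,
                                    %acc: vector<2xf32>) -> vector<2xf32> {
  // Reduce dimension 1 of %src and fold %acc into the result.
  %0 = vector.multi_reduction <add>, %src, %acc [1]
      : vector<2x4xf32> to vector<2xf32>
  return %0 : vector<2xf32>
}

// With -test-vector-multi-reduction-lowering-patterns this lowers to per-row
// vector.reduction ops seeded with the matching element of %acc, e.g.:
//   %v0   = vector.extract %src[0] : vector<2x4xf32>
//   %acc0 = vector.extract %acc[0] : vector<2xf32>
//   %r0   = vector.reduction <add>, %v0, %acc0 : vector<4xf32> into f32
```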