forked from OSchip/llvm-project
[mlir][vector] Add accumulator operand to MultiDimReduce op
This allows vectorizing linalg reductions without changing the operation order, so the vectorization remains valid even when the operations are not associative.

Differential Revision: https://reviews.llvm.org/D129535
parent 6b694d600a
commit 051b36ba28
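To make the effect concrete, here is a minimal sketch in MLIR (operand names and shapes are illustrative, not taken from the patch); before this change the vectorizer had to combine the reduction result with the initial value in a separate op, whereas the accumulator now feeds the reduction directly:

```mlir
// Before: reduce first, then fold the initial value in afterwards.
%r   = vector.multi_reduction <add>, %src [1] : vector<4x8xf32> to vector<4xf32>
%out = arith.addf %r, %acc : vector<4xf32>

// After: the accumulator is an operand of the reduction itself, so the
// combining order of the original linalg reduction is preserved.
%out = vector.multi_reduction <add>, %src, %acc [1] : vector<4x8xf32> to vector<4xf32>
```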
@@ -318,6 +318,7 @@ def Vector_ReductionOp :

def Vector_MultiDimReductionOp :
Vector_Op<"multi_reduction", [NoSideEffect,
+ AllTypesMatch<["dest", "acc"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
@@ -325,6 +326,7 @@ def Vector_MultiDimReductionOp :
["getShapeForUnroll"]>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVector:$source,
+ AnyType:$acc,
I64ArrayAttr:$reduction_dims)>,
Results<(outs AnyType:$dest)> {
let summary = "Multi-dimensional reduction operation";
@@ -332,19 +334,20 @@ def Vector_MultiDimReductionOp :
Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n)
using the given operation (add/mul/min/max for int/fp and and/or/xor for
int only).
+ Takes an initial accumulator operand.

Example:

```mlir
- %1 = vector.multi_reduction <add>, %0 [1, 3] :
+ %1 = vector.multi_reduction <add>, %0, %acc0 [1, 3] :
vector<4x8x16x32xf32> into vector<4x16xf32>
- %2 = vector.multi_reduction <add>, %1 [0, 1] :
+ %2 = vector.multi_reduction <add>, %1, %acc1 [0, 1] :
vector<4x16xf32> into f32
```
}];
let builders = [
- OpBuilder<(ins "Value":$source, "ArrayRef<bool>":$reductionMask,
- "CombiningKind":$kind)>
+ OpBuilder<(ins "Value":$source, "Value":$acc,
+ "ArrayRef<bool>":$reductionMask, "CombiningKind":$kind)>
];
let extraClassDeclaration = [{
static StringRef getKindAttrStrName() { return "kind"; }
@@ -378,8 +381,9 @@ def Vector_MultiDimReductionOp :
}
}];
let assemblyFormat =
- "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
+ "$kind `,` $source `,` $acc attr-dict $reduction_dims `:` type($source) `to` type($dest)";
let hasFolder = 1;
let hasVerifier = 1;
}

def Vector_BroadcastOp :
@@ -174,13 +174,13 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
- static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
- Value valueToReduce,
- const SmallVector<bool> &reductionMask) {
+ static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
+ Value valueToReduce, Value acc,
+ const SmallVector<bool> &reductionMask) {
auto maybeKind = getCombinerOpKind(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
return b.create<vector::MultiDimReductionOp>(
- reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
+ reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
}

static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
@@ -315,10 +315,7 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
(outputType && reduceType.getShape() == outputType.getShape()))
return nullptr;
SmallVector<bool> reductionMask = getReductionMask(linalgOp);
- Value reduce = buildMultiDimReduce(b, op, reduceVec, reductionMask);
- return b.create(op->getLoc(), op->getName().getIdentifier(),
- /*operands=*/{reduce, outputVec}, reduce.getType(),
- op->getAttrs());
+ return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
}

/// Generic vectorization for a single operation `op`, given already vectorized
@@ -334,34 +334,14 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,

void vector::MultiDimReductionOp::build(OpBuilder &builder,
OperationState &result, Value source,
- ArrayRef<bool> reductionMask,
+ Value acc, ArrayRef<bool> reductionMask,
CombiningKind kind) {
SmallVector<int64_t> reductionDims;
for (const auto &en : llvm::enumerate(reductionMask))
if (en.value())
reductionDims.push_back(en.index());
- build(builder, result, kind, source, builder.getI64ArrayAttr(reductionDims));
- }
-
- LogicalResult MultiDimReductionOp::inferReturnTypes(
- MLIRContext *, Optional<Location>, ValueRange operands,
- DictionaryAttr attributes, RegionRange,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- MultiDimReductionOp::Adaptor op(operands, attributes);
- auto vectorType = op.getSource().getType().cast<VectorType>();
- SmallVector<int64_t> targetShape;
- for (auto it : llvm::enumerate(vectorType.getShape()))
- if (!llvm::any_of(op.getReductionDims().getValue(), [&](Attribute attr) {
- return attr.cast<IntegerAttr>().getValue() == it.index();
- }))
- targetShape.push_back(it.value());
- // TODO: update to also allow 0-d vectors when available.
- if (targetShape.empty())
- inferredReturnTypes.push_back(vectorType.getElementType());
- else
- inferredReturnTypes.push_back(
- VectorType::get(targetShape, vectorType.getElementType()));
- return success();
+ build(builder, result, kind, source, acc,
+ builder.getI64ArrayAttr(reductionDims));
}

OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
@@ -375,6 +355,28 @@ Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
return llvm::to_vector<4>(getSourceVectorType().getShape());
}

+ LogicalResult MultiDimReductionOp::verify() {
+ SmallVector<int64_t> targetShape;
+ Type inferredReturnType;
+ for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
+ if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
+ return attr.cast<IntegerAttr>().getValue() == it.index();
+ }))
+ targetShape.push_back(it.value());
+ // TODO: update to also allow 0-d vectors when available.
+ if (targetShape.empty())
+ inferredReturnType = getSourceVectorType().getElementType();
+ else
+ inferredReturnType =
+ VectorType::get(targetShape, getSourceVectorType().getElementType());
+ if (getType() != inferredReturnType)
+ return emitOpError() << "destination type " << getType()
+ << " is incompatible with source type "
+ << getSourceVectorType();
+
+ return success();
+ }
+
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
@@ -87,8 +87,8 @@ public:
reductionMask[i] = true;
}
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
- multiReductionOp, transposeOp.getResult(), reductionMask,
- multiReductionOp.getKind());
+ multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
+ reductionMask, multiReductionOp.getKind());
return success();
}

@@ -188,11 +188,17 @@ public:
vectorShape, multiReductionOp.getSourceVectorType().getElementType());
Value cast = rewriter.create<vector::ShapeCastOp>(
loc, castedType, multiReductionOp.getSource());

+ Value acc = multiReductionOp.getAcc();
+ if (flattenedParallelDim) {
+ auto accType = VectorType::get(
+ {flattenedParallelDim},
+ multiReductionOp.getSourceVectorType().getElementType());
+ acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
+ }
// 5. Creates the flattened form of vector.multi_reduction with inner/outer
// most dim as reduction.
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
- loc, cast, mask, multiReductionOp.getKind());
+ loc, cast, acc, mask, multiReductionOp.getKind());

// 6. If there are no parallel shapes, the result is a scalar.
// TODO: support 0-d vectors when available.
@@ -238,10 +244,8 @@ struct TwoDimMultiReductionToElementWise
if (!elementType.isIntOrIndexOrFloat())
return failure();

- Value result =
- rewriter.create<vector::ExtractOp>(loc, multiReductionOp.getSource(), 0)
- .getResult();
- for (int64_t i = 1; i < srcShape[0]; i++) {
+ Value result = multiReductionOp.getAcc();
+ for (int64_t i = 0; i < srcShape[0]; i++) {
auto operand = rewriter.create<vector::ExtractOp>(
loc, multiReductionOp.getSource(), i);
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
@@ -277,8 +281,10 @@ struct TwoDimMultiReductionToReduction
for (int i = 0; i < outerDim; ++i) {
auto v = rewriter.create<vector::ExtractOp>(
loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
+ auto acc = rewriter.create<vector::ExtractOp>(
+ loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
auto reducedValue = rewriter.create<vector::ReductionOp>(
- loc, multiReductionOp.getKind(), v);
+ loc, multiReductionOp.getKind(), v, acc);
result = rewriter.create<vector::InsertElementOp>(
loc, reducedValue, result,
rewriter.create<arith::ConstantIndexOp>(loc, i));
@@ -309,6 +315,8 @@ struct OneDimMultiReductionToTwoDim
auto srcShape = srcVectorType.getShape();
auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
srcVectorType.getElementType());
+ auto accType =
+ VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
assert(!multiReductionOp.getDestType().isa<VectorType>() &&
"multi_reduction with a single dimension expects a scalar result");

@@ -319,8 +327,10 @@ struct OneDimMultiReductionToTwoDim
/// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
Value cast = rewriter.create<vector::ShapeCastOp>(
loc, castedType, multiReductionOp.getSource());
+ Value castAcc = rewriter.create<vector::BroadcastOp>(
+ loc, accType, multiReductionOp.getAcc());
Value reduced = rewriter.create<vector::MultiDimReductionOp>(
- loc, cast, mask, multiReductionOp.getKind());
+ loc, cast, castAcc, mask, multiReductionOp.getKind());
rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
ArrayRef<int64_t>{0});
return success();
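As a rough sketch of what the `OneDimMultiReductionToTwoDim` rewrite above now produces (value names and the vector<8xf32> shape are assumptions for illustration, not taken from the patch), the scalar accumulator is broadcast into a one-element vector so it can feed the casted 2-D reduction, and the scalar result is extracted back out:

```mlir
%cast    = vector.shape_cast %v : vector<8xf32> to vector<1x8xf32>
%castAcc = vector.broadcast %acc : f32 to vector<1xf32>
%red     = vector.multi_reduction <add>, %cast, %castAcc [1] : vector<1x8xf32> to vector<1xf32>
%res     = vector.extract %red[0] : vector<1xf32>
```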
@@ -997,11 +997,8 @@ struct MultiReduceToContract
}
auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
/*symCount=*/0, exprs, reduceOp.getContext());
- Value zero = rewriter.create<arith::ConstantOp>(
- reduceOp.getLoc(), reduceOp.getDestType(),
- rewriter.getZeroAttr(reduceOp.getDestType()));
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
- reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
+ reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
rewriter.getStrArrayAttr(iteratorTypes));
return success();
@@ -431,10 +431,11 @@ struct UnrollMultiReductionPattern
SmallVector<int64_t, 4> offsets =
getVectorOffset(originalSize, *targetShape, i);

+ SmallVector<Value> operands;
SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);
-
+ loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+ operands.push_back(slicedOperand);
SmallVector<int64_t> dstShape;
SmallVector<int64_t> destOffset;
for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
@@ -443,17 +444,22 @@ struct UnrollMultiReductionPattern
dstShape.push_back((*targetShape)[i]);
}
}
+ Value acc;
+ SmallVector<int64_t, 4> accStrides(destOffset.size(), 1);
+ // If a version of the accumulator has already been computed, use it
+ // otherwise extract the first version from the original operand.
+ auto accIt = accCache.find(destOffset);
+ if (accIt != accCache.end())
+ acc = accIt->second;
+ else
+ acc = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
+ operands.push_back(acc);
auto targetType = VectorType::get(
dstShape, reductionOp.getSourceVectorType().getElementType());
Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
- slicedOperand, targetType);
+ operands, targetType);
Value result = newOp->getResult(0);
- // Save the accumulated value until all the loops are unrolled since
- // reduction loop keeps updating the accumulator.
- auto accIt = accCache.find(destOffset);
- if (accIt != accCache.end())
- result = makeArithReduction(rewriter, loc, reductionOp.getKind(),
- result, accIt->second);
accCache[destOffset] = result;
}
// Assemble back the accumulator into a single vector.
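To illustrate the accumulator caching in the unrolling pattern above, here is a hand-written sketch (shapes, offsets, and value names are assumptions): when the reduction dimension is split, the first slice consumes the original accumulator, and every later slice with the same destination offset reuses the previous partial result as its accumulator instead of being re-combined with it afterwards.

```mlir
// Unrolling a vector<2x8xf32> -> vector<2xf32> reduction with target shape [2, 4].
%s0 = vector.extract_strided_slice %src {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]}
        : vector<2x8xf32> to vector<2x4xf32>
%r0 = vector.multi_reduction <add>, %s0, %acc [1] : vector<2x4xf32> to vector<2xf32>
%s1 = vector.extract_strided_slice %src {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]}
        : vector<2x8xf32> to vector<2x4xf32>
// The cached partial result %r0 becomes the accumulator of the next slice.
%r1 = vector.multi_reduction <add>, %s1, %r0 [1] : vector<2x4xf32> to vector<2xf32>
```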
@@ -10,9 +10,8 @@ func.func @vectorize_matmul(%arg0: tensor<24x12xf32>,
// CHECK: %[[vA:.+]] = vector.transfer_read %[[A]]
// CHECK: %[[vB:.+]] = vector.transfer_read %[[B]]
// CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
- // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
- // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
- // CHECK: vector.transfer_write %[[vS]], %[[C]]
+ // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+ // CHECK: vector.transfer_write %[[vR]], %[[C]]
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}
@@ -67,9 +66,8 @@ func.func @vectorize_keep_pad(
// CHECK: %[[vA:.+]] = vector.transfer_read %[[pA]]
// CHECK: %[[vB:.+]] = vector.transfer_read %[[pB]]
// CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
- // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
- // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
- // CHECK: vector.transfer_write %[[vS]], %[[C]]
+ // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+ // CHECK: vector.transfer_write %[[vR]], %[[C]]
%8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
%9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
return %9 : tensor<24x25xf32>
@@ -127,9 +125,8 @@ func.func @vectorize_pad(
tensor.yield %cst : f32
} : tensor<?x5xf32> to tensor<7x5xf32>
// CHECK: %[[vC:.+]] = vector.transfer_read %[[C]]
- // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]]
- // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]]
- // CHECK: vector.transfer_write %[[vS]], %[[C]]
+ // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]], %[[vC]]
+ // CHECK: vector.transfer_write %[[vR]], %[[C]]
%8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
%9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
return %9 : tensor<24x25xf32>
@@ -6,8 +6,7 @@
func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {

// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584xf32>
- // CHECK: vector.multi_reduction <add>, %{{.*}} [0] : vector<1584xf32> to f32
- // CHECK: arith.addf %{{.*}}, %{{.*}} : f32
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [0] : vector<1584xf32> to f32
linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
outs(%C: memref<f32>)
return
@@ -19,8 +18,7 @@ func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memre
func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {

// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
- // CHECK: vector.multi_reduction <add>, %{{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
- // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [1] : vector<1584x1584xf32> to vector<1584xf32>
linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
outs(%C: memref<1584xf32>)
return
@@ -31,8 +29,7 @@ func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %
// CHECK-LABEL: contraction_matmul
func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
- // CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
- // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<1584x1584x1584xf32> to vector<1584x1584xf32>
linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
outs(%C: memref<1584x1584xf32>)
return
@@ -43,8 +40,7 @@ func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf3
// CHECK-LABEL: contraction_batch_matmul
func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
- // CHECK: vector.multi_reduction <add>, %{{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
- // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [3] : vector<1584x1584x1584x1584xf32> to vector<1584x1584x1584xf32>
linalg.batch_matmul
ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
outs(%C: memref<1584x1584x1584xf32>)
@@ -69,10 +65,9 @@ func.func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
- // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
+ // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
- // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
linalg.generic #matmul_trait
ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -103,10 +98,9 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<32x8xf32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
- // CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
+ // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
// CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
- // CHECK: arith.addf %[[R]], %{{.*}} : vector<8x32xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32>
linalg.generic #matmul_transpose_out_trait
ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
@@ -157,11 +151,9 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32
%C: memref<8x32xi32>) {
// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32>
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32>
- // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+ // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
// CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32>
- // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
- // CHECK: arith.addi %[[R]], %{{.*}} : vector<8x32xi32>
-
+ // CHECK: vector.multi_reduction <add>, %[[MUL]], %[[ACC]] [2] : vector<8x32x16xi32> to vector<8x32xi32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
linalg.generic #matmul_trait
ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
@@ -180,8 +172,7 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32
func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32>
- // CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
- // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8x32xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, {{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32>
linalg.matmul
ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
outs(%C: memref<8x32xf32>)
@@ -560,9 +551,8 @@ func.func @matmul_tensors(
// linalg matmul lowers gets expanded to a 3D reduction, canonicalization later
// convert it to a 2D contract.
// CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32>
- // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
- // CHECK: %[[ADD:.*]] = arith.addf %[[R]], %[[V2]] : vector<8x12xf32>
- // CHECK: %[[W:.*]] = vector.transfer_write %[[ADD]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
+ // CHECK: %[[R:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]] [2] : vector<8x12x4xf32> to vector<8x12xf32>
+ // CHECK: %[[W:.*]] = vector.transfer_write %[[R]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
outs(%arg2: tensor<8x12xf32>)
-> tensor<8x12xf32>
@@ -801,8 +791,7 @@ func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
// CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
// CHECK: math.exp {{.*}} : vector<4x16x8xf32>
- // CHECK: vector.multi_reduction <add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
- // CHECK: addf {{.*}} : vector<4x16xf32>
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
// CHECK: return {{.*}} : tensor<4x16xf32>
%0 = linalg.generic {
@@ -836,8 +825,7 @@ func.func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
// CHECK: addf {{.*}} : vector<2x3x4x5xf32>
- // CHECK: vector.multi_reduction <add>, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
- // CHECK: addf {{.*}} : vector<2x5xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, %{{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
// CHECK: return {{.*}} : tensor<5x2xf32>
%0 = linalg.generic {
@@ -865,8 +853,7 @@ func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
- // CHECK: %[[R:.+]] = vector.multi_reduction <maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
- // CHECK: maxf %[[R]], %[[CMINF]] : vector<4xf32>
+ // CHECK: vector.multi_reduction <maxf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant -3.40282e+38 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -890,8 +877,7 @@ func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
- // CHECK: %[[R:.+]] = vector.multi_reduction <minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
- // CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32>
+ // CHECK: vector.multi_reduction <minf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%maxf32 = arith.constant 3.40282e+38 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -914,7 +900,7 @@ func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: linalg.init_tensor [4] : tensor<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
- // CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+ // CHECK: vector.multi_reduction <mul>, {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant 1.0 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -937,7 +923,7 @@ func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
- // CHECK: vector.multi_reduction <or>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+ // CHECK: vector.multi_reduction <or>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant false
%init = linalg.init_tensor [4] : tensor<4xi1>
@@ -960,7 +946,7 @@ func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
- // CHECK: vector.multi_reduction <and>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+ // CHECK: vector.multi_reduction <and>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant true
%init = linalg.init_tensor [4] : tensor<4xi1>
@@ -983,7 +969,7 @@ func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
// CHECK: linalg.init_tensor [4] : tensor<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
- // CHECK: vector.multi_reduction <xor>, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
+ // CHECK: vector.multi_reduction <xor>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
// CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
%ident = arith.constant false
%init = linalg.init_tensor [4] : tensor<4xi1>
@@ -1035,8 +1021,7 @@ func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>
// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32>
// CHECK: subf {{.*}} : vector<4x4xf32>
// CHECK: math.exp {{.*}} : vector<4x4xf32>
- // CHECK: vector.multi_reduction <add>, {{.*}} : vector<4x4xf32> to vector<4xf32>
- // CHECK: addf {{.*}} : vector<4xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} : vector<4x4xf32> to vector<4xf32>
// CHECK: vector.transfer_write {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<4xf32>
%c0 = arith.constant 0.0 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -1075,10 +1060,9 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
// CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
- // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]] [0]
+ // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
// CHECK-SAME: : vector<32xf32> to f32
- // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32
- // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<f32>
+ // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
// CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
// CHECK-SAME: : vector<f32>, tensor<f32>
%2 = linalg.generic {
@@ -1281,9 +1281,9 @@ func.func @do_not_swap_extract_slice_transfer_write(%arg0 : vector<8xf32>,
// -----

// CHECK-LABEL: func @vector_multi_reduction_single_parallel(
- // CHECK-SAME: %[[v:.*]]: vector<2xf32>
- func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0 [] : vector<2xf32> to vector<2xf32>
+ // CHECK-SAME: %[[v:.*]]: vector<2xf32>,
+ func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [] : vector<2xf32> to vector<2xf32>

// CHECK: return %[[v]] : vector<2xf32>
return %0 : vector<2xf32>
@@ -1138,9 +1138,9 @@ func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {

// -----

- func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>) -> f32 {
- // expected-error@+1 {{'vector.multi_reduction' op inferred type(s) 'vector<4xf32>' are incompatible with return type(s) of operation 'vector<16xf32>'}}
- %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<4x16xf32> to vector<16xf32>
+ func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>, %acc: vector<16xf32>) -> f32 {
+ // expected-error@+1 {{'vector.multi_reduction' op destination type 'vector<16xf32>' is incompatible with source type 'vector<4x16xf32>'}}
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<4x16xf32> to vector<16xf32>
}

// -----
@@ -705,10 +705,13 @@ func.func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
}

// CHECK-LABEL: @multi_reduction
- func.func @multi_reduction(%0: vector<4x8x16x32xf32>) -> f32 {
- %1 = vector.multi_reduction <add>, %0 [1, 3] :
+ func.func @multi_reduction(%0: vector<4x8x16x32xf32>, %acc0: vector<4x16xf32>,
+ %acc1: f32) -> f32 {
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [1, 3] : vector<4x8x16x32xf32> to vector<4x16xf32>
+ %1 = vector.multi_reduction <add>, %0, %acc0 [1, 3] :
vector<4x8x16x32xf32> to vector<4x16xf32>
- %2 = vector.multi_reduction <add>, %1 [0, 1] :
+ // CHECK: vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x16xf32> to f32
+ %2 = vector.multi_reduction <add>, %1, %acc1 [0, 1] :
vector<4x16xf32> to f32
return %2 : f32
}
@@ -1,40 +1,42 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s

- func.func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
// CHECK: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
- // CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]] : vector<4xf32> into f32
+ // CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+ // CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
- // CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]] : vector<4xf32> into f32
+ // CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+ // CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]]

- func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
- %0 = vector.multi_reduction <mul>, %arg0 [0, 1] : vector<2x4xf32> to f32
+ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
return %0 : f32
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
- // CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]] : vector<8xf32> into f32
+ // CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
// CHECK: %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
// CHECK: return %[[RES]]

- func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
- %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
}
// CHECK-LABEL: func @vector_reduction_inner
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
// CHECK: %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
@@ -44,29 +46,35 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32>
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK: %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32>
- // CHECK: %[[V0R:.+]] = vector.reduction <add>, %[[V0]] : vector<20xi32> into i32
+ // CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : vector<2x3xi32>
+ // CHECK: %[[V0R:.+]] = vector.reduction <add>, %[[V0]], %[[ACC0]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : index] : vector<6xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32>
- // CHECK: %[[V1R:.+]] = vector.reduction <add>, %[[V1]] : vector<20xi32> into i32
+ // CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : vector<2x3xi32>
+ // CHECK: %[[V1R:.+]] = vector.reduction <add>, %[[V1]], %[[ACC1]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : index] : vector<6xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32>
- // CHECK: %[[V2R:.+]] = vector.reduction <add>, %[[V2]] : vector<20xi32> into i32
+ // CHECK: %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : vector<2x3xi32>
+ // CHECK: %[[V2R:.+]] = vector.reduction <add>, %[[V2]], %[[ACC2]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : index] : vector<6xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32>
- // CHECK: %[[V3R:.+]] = vector.reduction <add>, %[[V3]] : vector<20xi32> into i32
+ // CHECK: %[[ACC3:.+]] = vector.extract %[[ACC]][1, 0] : vector<2x3xi32>
+ // CHECK: %[[V3R:.+]] = vector.reduction <add>, %[[V3]], %[[ACC3]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : index] : vector<6xi32>
// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32>
- // CHECK: %[[V4R:.+]] = vector.reduction <add>, %[[V4]] : vector<20xi32> into i32
+ // CHECK: %[[ACC4:.+]] = vector.extract %[[ACC]][1, 1] : vector<2x3xi32>
+ // CHECK: %[[V4R:.+]] = vector.reduction <add>, %[[V4]], %[[ACC4]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : index] : vector<6xi32>
/// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32>
- // CHECK: %[[V5R:.+]] = vector.reduction <add>, %[[V5]] : vector<20xi32> into i32
+ // CHECK: %[[ACC5:.+]] = vector.extract %[[ACC]][1, 2] : vector<2x3xi32>
+ // CHECK: %[[V5R:.+]] = vector.reduction <add>, %[[V5]], %[[ACC5]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : index] : vector<6xi32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
// CHECK: return %[[RESULT]]

- func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
- %0 = vector.multi_reduction <add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
return %0 : vector<2x5xf32>
}

@@ -77,12 +85,12 @@ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vect
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
// CHECK: return %[[RESULT]]

- func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
- %0 = vector.multi_reduction <mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
+ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
return %0 : vector<2x4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_ordering
- // CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32>, %[[ACC:.*]]: vector<2x4xf32>)
// CHECK: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<8xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -94,28 +102,36 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2
// CHECK: %[[C7:.+]] = arith.constant 7 : index
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
- // CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]] : vector<3xf32> into f32
+ // CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : vector<2x4xf32>
+ // CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<8xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
- // CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]] : vector<3xf32> into f32
+ // CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : vector<2x4xf32>
+ // CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<8xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
- // CHECK: %[[RV2:.+]] = vector.reduction <mul>, %[[V2]] : vector<3xf32> into f32
+ // CHECK: %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : vector<2x4xf32>
+ // CHECK: %[[RV2:.+]] = vector.reduction <mul>, %[[V2]], %[[ACC2]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : index] : vector<8xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
- // CHECK: %[[RV3:.+]] = vector.reduction <mul>, %[[V3]] : vector<3xf32> into f32
+ // CHECK: %[[ACC3:.+]] = vector.extract %[[ACC]][0, 3] : vector<2x4xf32>
+ // CHECK: %[[RV3:.+]] = vector.reduction <mul>, %[[V3]], %[[ACC3]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : index] : vector<8xf32>
// CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
- // CHECK: %[[RV4:.+]] = vector.reduction <mul>, %[[V4]] : vector<3xf32> into f32
+ // CHECK: %[[ACC4:.+]] = vector.extract %[[ACC]][1, 0] : vector<2x4xf32>
+ // CHECK: %[[RV4:.+]] = vector.reduction <mul>, %[[V4]], %[[ACC4]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : index] : vector<8xf32>
// CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
- // CHECK: %[[RV5:.+]] = vector.reduction <mul>, %[[V5]] : vector<3xf32> into f32
+ // CHECK: %[[ACC5:.+]] = vector.extract %[[ACC]][1, 1] : vector<2x4xf32>
+ // CHECK: %[[RV5:.+]] = vector.reduction <mul>, %[[V5]], %[[ACC5]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : index] : vector<8xf32>
// CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
- // CHECK: %[[RV6:.+]] = vector.reduction <mul>, %[[V6]] : vector<3xf32> into f32
+ // CHECK: %[[ACC6:.+]] = vector.extract %[[ACC]][1, 2] : vector<2x4xf32>
+ // CHECK: %[[RV6:.+]] = vector.reduction <mul>, %[[V6]], %[[ACC6]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : index] : vector<8xf32>
// CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
- // CHECK: %[[RV7:.+]] = vector.reduction <mul>, %[[V7]] : vector<3xf32> into f32
+ // CHECK: %[[ACC7:.+]] = vector.extract %[[ACC]][1, 3] : vector<2x4xf32>
+ // CHECK: %[[RV7:.+]] = vector.reduction <mul>, %[[V7]], %[[ACC7]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32>
// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
// CHECK: return %[[RESHAPED_VEC]]
@@ -1,101 +1,107 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s

- func.func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}

// CHECK-LABEL: func @vector_multi_reduction
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+ // CHECK: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
- // CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[V0]] : vector<2xf32>
+ // CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
// CHECK: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
// CHECK: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>

- func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <minf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}

// CHECK-LABEL: func @vector_multi_reduction_min
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+ // CHECK: %[[RV0:.+]] = arith.minf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
- // CHECK: %[[RV01:.+]] = arith.minf %[[V1]], %[[V0]] : vector<2xf32>
+ // CHECK: %[[RV01:.+]] = arith.minf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
// CHECK: %[[RV012:.+]] = arith.minf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
// CHECK: %[[RESULT_VEC:.+]] = arith.minf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>

- func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <maxf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}

// CHECK-LABEL: func @vector_multi_reduction_max
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
+ // CHECK: %[[RV0:.+]] = arith.maxf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
- // CHECK: %[[RV01:.+]] = arith.maxf %[[V1]], %[[V0]] : vector<2xf32>
+ // CHECK: %[[RV01:.+]] = arith.maxf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
// CHECK: %[[RV012:.+]] = arith.maxf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
// CHECK: %[[RESULT_VEC:.+]] = arith.maxf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>

- func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction <and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+ func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ %0 = vector.multi_reduction <and>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}

// CHECK-LABEL: func @vector_multi_reduction_and
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+ // CHECK: %[[RV0:.+]] = arith.andi %[[V0]], %[[ACC]] : vector<2xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
- // CHECK: %[[RV01:.+]] = arith.andi %[[V1]], %[[V0]] : vector<2xi32>
+ // CHECK: %[[RV01:.+]] = arith.andi %[[V1]], %[[RV0]] : vector<2xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
// CHECK: %[[RV012:.+]] = arith.andi %[[V2]], %[[RV01]] : vector<2xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
// CHECK: %[[RESULT_VEC:.+]] = arith.andi %[[V3]], %[[RV012]] : vector<2xi32>
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>

- func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction <or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+ func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ %0 = vector.multi_reduction <or>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}

// CHECK-LABEL: func @vector_multi_reduction_or
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+ // CHECK: %[[RV0:.+]] = arith.ori %[[V0]], %[[ACC]] : vector<2xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
- // CHECK: %[[RV01:.+]] = arith.ori %[[V1]], %[[V0]] : vector<2xi32>
+ // CHECK: %[[RV01:.+]] = arith.ori %[[V1]], %[[RV0]] : vector<2xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
// CHECK: %[[RV012:.+]] = arith.ori %[[V2]], %[[RV01]] : vector<2xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
// CHECK: %[[RESULT_VEC:.+]] = arith.ori %[[V3]], %[[RV012]] : vector<2xi32>
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>

- func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction <xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
+ func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ %0 = vector.multi_reduction <xor>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
return %0 : vector<2xi32>
}

// CHECK-LABEL: func @vector_multi_reduction_xor
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
+ // CHECK: %[[RV0:.+]] = arith.xori %[[V0]], %[[ACC]] : vector<2xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
- // CHECK: %[[RV01:.+]] = arith.xori %[[V1]], %[[V0]] : vector<2xi32>
+ // CHECK: %[[RV01:.+]] = arith.xori %[[V1]], %[[RV0]] : vector<2xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
// CHECK: %[[RV012:.+]] = arith.xori %[[V2]], %[[RV01]] : vector<2xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
@@ -103,18 +109,20 @@ func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>

- func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
- %0 = vector.multi_reduction <add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+ func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
}

// CHECK-LABEL: func @vector_reduction_outer
- // CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>
+ // CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32>
// CHECK: %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32>
+ // CHECK: %[[FACC:.+]] = vector.shape_cast %[[ACC]] : vector<2x3xi32> to vector<6xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<20x6xi32>
+ // CHECK: %[[R:.+]] = arith.addi %[[V0]], %[[FACC]] : vector<6xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<20x6xi32>
- // CHECK: %[[R0:.+]] = arith.addi %[[V1]], %[[V0]] : vector<6xi32>
+ // CHECK: %[[R0:.+]] = arith.addi %[[V1]], %[[R]] : vector<6xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED]][2] : vector<20x6xi32>
// CHECK: %[[R1:.+]] = arith.addi %[[V2]], %[[R0]] : vector<6xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<20x6xi32>
@ -157,15 +165,15 @@ func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32>
// This test is mainly to catch a bug that running
// `InnerOuterDimReductionConversion` on this function results in an
// infinite loop. So just check that some value is returned.
func.func @vector_reduction_1D(%arg0 : vector<2xf32>) -> f32 {
%0 = vector.multi_reduction #vector.kind<maxf>, %arg0 [0] : vector<2xf32> to f32
func.func @vector_reduction_1D(%arg0 : vector<2xf32>, %acc: f32) -> f32 {
%0 = vector.multi_reduction #vector.kind<maxf>, %arg0, %acc [0] : vector<2xf32> to f32
return %0 : f32
}
// CHECK-LABEL: func @vector_reduction_1D
// CHECK: return %{{.+}}

func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>) -> f32 {
%0 = vector.multi_reduction <add>, %arg0 [0, 1] : vector<2x3xf32> to f32
func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
%0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
return %0 : f32
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
@ -4,15 +4,15 @@
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>

// CHECK-LABEL: multidimreduction_contract
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>)
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
// CHECK-NEXT: return %[[R]] : vector<8x16xf32>
func.func @multidimreduction_contract(
%arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
%arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>, %acc: vector<8x16xf32>) -> vector<8x16xf32> {
%0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
%1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
%1 = vector.multi_reduction <add>, %0, %acc [1] : vector<8x32x16xf32> to vector<8x16xf32>
return %1 : vector<8x16xf32>
}
@ -22,15 +22,15 @@ func.func @multidimreduction_contract(
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>

// CHECK-LABEL: multidimreduction_contract_int
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32>
// CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>)
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
// CHECK-NEXT: return %[[R]] : vector<8x16xi32>
func.func @multidimreduction_contract_int(
%arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
%arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>, %acc: vector<8x16xi32>) -> vector<8x16xi32> {
%0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
%1 = vector.multi_reduction <add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
%1 = vector.multi_reduction <add>, %0, %acc [1] : vector<8x32x16xi32> to vector<8x16xi32>
return %1 : vector<8x16xi32>
}
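Read together, the two tests above check that a `mulf`/`muli` followed by a `multi_reduction` now folds into a `vector.contract` that takes the function's accumulator directly instead of a zero constant. A minimal sketch of the resulting IR follows; `#map0` does not appear in this hunk, so its identity form is an assumption, and the function name is invented.

```mlir
// Hypothetical rewritten form (assumes #map0 is the identity map).
func.func @contract_sketch(%lhs: vector<8x32x16xf32>, %rhs: vector<8x32x16xf32>,
                           %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  // d1 is the reduction dimension; %acc becomes the contraction accumulator.
  %r = vector.contract {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d2)>],
         iterator_types = ["parallel", "reduction", "parallel"],
         kind = #vector.kind<add>}
       %lhs, %rhs, %acc
     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
  return %r : vector<8x16xf32>
}
```
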
@ -188,30 +188,28 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>

func.func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v [1] : vector<4x6xf32> to vector<4xf32>
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
// CHECK: %[[V0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK: %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[ACC0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
// CHECK: %[[R0:.*]] = vector.multi_reduction <add>, %[[E0]], %[[ACC0]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK: %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[A0:.*]] = arith.addf %[[R1]], %[[R0]] : vector<2xf32>
// CHECK: %[[R1:.*]] = vector.multi_reduction <add>, %[[E1]], %[[R0]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK: %[[R2:.*]] = vector.multi_reduction <add>, %5 [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[A1:.*]] = arith.addf %[[R2]], %[[A0]] : vector<2xf32>
// CHECK: %[[R2:.*]] = vector.multi_reduction <add>, %[[E2]], %[[R1]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK: %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[ACC1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
// CHECK: %[[R3:.*]] = vector.multi_reduction <add>, %[[E3]], %[[ACC1]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK: %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[A2:.*]] = arith.addf %[[R4]], %[[R3]] : vector<2xf32>
// CHECK: %[[R4:.*]] = vector.multi_reduction <add>, %[[E4]], %[[R3]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK: %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[A3:.*]] = arith.addf %[[R5]], %[[A2]] : vector<2xf32>
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: %[[R5:.*]] = vector.multi_reduction <add>, %[[E5]], %[[R4]] [1] : vector<2x2xf32> to vector<2xf32>
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[R2]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: return %[[V2]] : vector<4xf32>
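The unrolling test above shows the other effect of the accumulator operand: each 2x2 slice reduction now takes either a slice of `%acc` or the previous partial result as its accumulator, so the separate `arith.addf` merges of the old checks disappear. A minimal sketch of that pattern for the first half of the result, with invented names, is shown below.

```mlir
// Hypothetical unrolled half (names invented; offsets taken from the checks above).
func.func @unroll_sketch(%v: vector<4x6xf32>, %acc: vector<4xf32>) -> vector<2xf32> {
  // Slice of the accumulator that seeds the chain for rows 0-1.
  %acc0 = vector.extract_strided_slice %acc {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
  %e0 = vector.extract_strided_slice %v {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
  %r0 = vector.multi_reduction <add>, %e0, %acc0 [1] : vector<2x2xf32> to vector<2xf32>
  %e1 = vector.extract_strided_slice %v {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
  %r1 = vector.multi_reduction <add>, %e1, %r0 [1] : vector<2x2xf32> to vector<2xf32>
  %e2 = vector.extract_strided_slice %v {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
  %r2 = vector.multi_reduction <add>, %e2, %r1 [1] : vector<2x2xf32> to vector<2xf32>
  return %r2 : vector<2xf32>
}
```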