forked from OSchip/llvm-project
[mlir] Add support for moving reductions to outer most dimensions in vector.multi_reduction
The approach for handling reductions in the outer most dimension follows that for inner most dimensions, outlined below First, transpose to move reduction dims, if needed Convert reduction from n-d to 2-d canonical form Then, for outer reductions, we emit the appropriate op (add/mul/min/max/or/and/xor) and combine the results. Differential Revision: https://reviews.llvm.org/D107675
This commit is contained in:
parent
8e9ffa1dc6
commit
e33f301ec2
|
@ -81,7 +81,8 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
|
||||||
|
|
||||||
// Collect a set of patterns to convert vector.multi_reduction op into
|
// Collect a set of patterns to convert vector.multi_reduction op into
|
||||||
// a sequence of vector.reduction ops.
|
// a sequence of vector.reduction ops.
|
||||||
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
|
void populateVectorMultiReductionLoweringPatterns(
|
||||||
|
RewritePatternSet &patterns, bool useInnerDimsForReduction = false);
|
||||||
|
|
||||||
/// Collect a set of patterns to propagate insert_map/extract_map in the ssa
|
/// Collect a set of patterns to propagate insert_map/extract_map in the ssa
|
||||||
/// chain.
|
/// chain.
|
||||||
|
|
|
@ -3490,12 +3490,18 @@ private:
|
||||||
const bool enableIndexOptimizations;
|
const bool enableIndexOptimizations;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Converts vector.multi_reduction into inner-most reduction form by inserting
|
// Converts vector.multi_reduction into inner-most/outer-most reduction form
|
||||||
// vector.transpose
|
// by using vector.tranpose
|
||||||
struct InnerDimReductionConversion
|
class InnerOuterDimReductionConversion
|
||||||
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
||||||
|
public:
|
||||||
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
explicit InnerOuterDimReductionConversion(MLIRContext *context,
|
||||||
|
bool useInnerDimsForReduction)
|
||||||
|
: mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
|
||||||
|
useInnerDimsForReduction(useInnerDimsForReduction) {}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto src = multiReductionOp.source();
|
auto src = multiReductionOp.source();
|
||||||
|
@ -3516,92 +3522,203 @@ struct InnerDimReductionConversion
|
||||||
parallelDims.push_back(i);
|
parallelDims.push_back(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add transpose only if inner-most dimensions are not reductions
|
// Add transpose only if inner-most/outer-most dimensions are not parallel
|
||||||
if (parallelDims ==
|
if (useInnerDimsForReduction &&
|
||||||
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))
|
(parallelDims ==
|
||||||
|
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!useInnerDimsForReduction &&
|
||||||
|
(parallelDims !=
|
||||||
|
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<int64_t, 4> indices;
|
SmallVector<int64_t, 4> indices;
|
||||||
indices.append(parallelDims.begin(), parallelDims.end());
|
if (useInnerDimsForReduction) {
|
||||||
indices.append(reductionDims.begin(), reductionDims.end());
|
indices.append(parallelDims.begin(), parallelDims.end());
|
||||||
|
indices.append(reductionDims.begin(), reductionDims.end());
|
||||||
|
} else {
|
||||||
|
indices.append(reductionDims.begin(), reductionDims.end());
|
||||||
|
indices.append(parallelDims.begin(), parallelDims.end());
|
||||||
|
}
|
||||||
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
|
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
|
||||||
SmallVector<bool> reductionMask(srcRank, false);
|
SmallVector<bool> reductionMask(srcRank, false);
|
||||||
for (int i = 0; i < reductionSize; ++i) {
|
for (int i = 0; i < reductionSize; ++i) {
|
||||||
reductionMask[srcRank - i - 1] = true;
|
if (useInnerDimsForReduction)
|
||||||
|
reductionMask[srcRank - i - 1] = true;
|
||||||
|
else
|
||||||
|
reductionMask[i] = true;
|
||||||
}
|
}
|
||||||
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
|
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
|
||||||
multiReductionOp, transposeOp.result(), reductionMask,
|
multiReductionOp, transposeOp.result(), reductionMask,
|
||||||
multiReductionOp.kind());
|
multiReductionOp.kind());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const bool useInnerDimsForReduction;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Reduces the rank of vector.mult_reduction nd -> 2d given all reduction
|
// Reduces the rank of vector.mult_reduction nd -> 2d given all reduction
|
||||||
// dimensions are inner most.
|
// dimensions are either inner most or outer most.
|
||||||
struct ReduceMultiDimReductionRank
|
class ReduceMultiDimReductionRank
|
||||||
|
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
explicit ReduceMultiDimReductionRank(MLIRContext *context,
|
||||||
|
bool useInnerDimsForReduction)
|
||||||
|
: mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
|
||||||
|
useInnerDimsForReduction(useInnerDimsForReduction) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
||||||
|
auto srcShape = multiReductionOp.getSourceVectorType().getShape();
|
||||||
|
auto loc = multiReductionOp.getLoc();
|
||||||
|
if (srcRank == 2)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Separate reduction and parallel dims
|
||||||
|
auto reductionDimsRange =
|
||||||
|
multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
|
||||||
|
auto reductionDims = llvm::to_vector<4>(llvm::map_range(
|
||||||
|
reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
|
||||||
|
llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
|
||||||
|
reductionDims.end());
|
||||||
|
SmallVector<int64_t, 4> parallelDims, parallelShapes;
|
||||||
|
int canonicalReductionDim = 1;
|
||||||
|
int canonicalParallelDim = 1;
|
||||||
|
for (int64_t i = 0; i < srcRank; i++) {
|
||||||
|
if (!reductionDimsSet.contains(i)) {
|
||||||
|
parallelDims.push_back(i);
|
||||||
|
parallelShapes.push_back(srcShape[i]);
|
||||||
|
canonicalParallelDim *= srcShape[i];
|
||||||
|
} else {
|
||||||
|
canonicalReductionDim *= srcShape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fail if reduction dims are not either inner-most or outer-most
|
||||||
|
if (useInnerDimsForReduction &&
|
||||||
|
(parallelDims !=
|
||||||
|
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!useInnerDimsForReduction &&
|
||||||
|
(parallelDims ==
|
||||||
|
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Creates shape cast for the inputs n_d -> 2d
|
||||||
|
int64_t outerDim =
|
||||||
|
useInnerDimsForReduction ? canonicalParallelDim : canonicalReductionDim;
|
||||||
|
int64_t innerDim =
|
||||||
|
useInnerDimsForReduction ? canonicalReductionDim : canonicalParallelDim;
|
||||||
|
|
||||||
|
auto castedType = VectorType::get(
|
||||||
|
ArrayRef<int64_t>{outerDim, innerDim},
|
||||||
|
multiReductionOp.getSourceVectorType().getElementType());
|
||||||
|
auto castedOp = rewriter.create<vector::ShapeCastOp>(
|
||||||
|
loc, castedType, multiReductionOp.source());
|
||||||
|
|
||||||
|
// Creates the canonical form of 2d vector.multi_reduction with inner/outer
|
||||||
|
// most dim as reduction.
|
||||||
|
SmallVector<bool, 2> mask{!useInnerDimsForReduction,
|
||||||
|
useInnerDimsForReduction};
|
||||||
|
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
|
||||||
|
loc, castedOp.result(), mask, multiReductionOp.kind());
|
||||||
|
|
||||||
|
// Creates shape cast for the output 2d -> nd
|
||||||
|
VectorType outputCastedType = VectorType::get(
|
||||||
|
parallelShapes,
|
||||||
|
multiReductionOp.getSourceVectorType().getElementType());
|
||||||
|
Value castedOutputOp = rewriter.create<vector::ShapeCastOp>(
|
||||||
|
loc, outputCastedType, newOp.dest());
|
||||||
|
|
||||||
|
rewriter.replaceOp(multiReductionOp, castedOutputOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const bool useInnerDimsForReduction;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Unrolls vector.multi_reduction with outermost reductions
|
||||||
|
// and combines results
|
||||||
|
struct UnrollOuterMultiReduction
|
||||||
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
||||||
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
||||||
auto srcShape = multiReductionOp.getSourceVectorType().getShape();
|
if (srcRank != 2)
|
||||||
if (srcRank == 2)
|
return failure();
|
||||||
|
|
||||||
|
if (multiReductionOp.getReductionMask()[1] ||
|
||||||
|
!multiReductionOp.getReductionMask()[0])
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto loc = multiReductionOp.getLoc();
|
auto loc = multiReductionOp.getLoc();
|
||||||
auto reductionDims = llvm::to_vector<4>(
|
ArrayRef<int64_t> srcShape =
|
||||||
llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
|
multiReductionOp.getSourceVectorType().getShape();
|
||||||
[](Attribute attr) -> int64_t {
|
|
||||||
return attr.cast<IntegerAttr>().getInt();
|
|
||||||
}));
|
|
||||||
llvm::sort(reductionDims);
|
|
||||||
|
|
||||||
// Fails if not inner most reduction.
|
Type elementType = multiReductionOp.getDestVectorType().getElementType();
|
||||||
int64_t reductionSize = reductionDims.size();
|
if (!elementType.isIntOrIndexOrFloat())
|
||||||
bool innerMostReduction = true;
|
|
||||||
for (int i = 0; i < reductionSize; ++i) {
|
|
||||||
if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
|
|
||||||
innerMostReduction = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!innerMostReduction)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Extracts 2d rank reduction shape.
|
Value condition;
|
||||||
int innerDims = 1;
|
Value result =
|
||||||
int outterDims = 1;
|
rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
|
||||||
SmallVector<int64_t> innerDimsShape;
|
.getResult();
|
||||||
for (int i = 0; i < srcRank; ++i) {
|
for (int64_t i = 1; i < srcShape[0]; i++) {
|
||||||
if (i < (srcRank - reductionSize)) {
|
auto operand =
|
||||||
innerDims *= srcShape[i];
|
rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
|
||||||
innerDimsShape.push_back(srcShape[i]);
|
switch (multiReductionOp.kind()) {
|
||||||
} else {
|
case vector::CombiningKind::ADD:
|
||||||
outterDims *= srcShape[i];
|
if (elementType.isIntOrIndex())
|
||||||
|
result = rewriter.create<AddIOp>(loc, operand, result);
|
||||||
|
else
|
||||||
|
result = rewriter.create<AddFOp>(loc, operand, result);
|
||||||
|
break;
|
||||||
|
case vector::CombiningKind::MUL:
|
||||||
|
if (elementType.isIntOrIndex())
|
||||||
|
result = rewriter.create<MulIOp>(loc, operand, result);
|
||||||
|
else
|
||||||
|
result = rewriter.create<MulFOp>(loc, operand, result);
|
||||||
|
break;
|
||||||
|
case vector::CombiningKind::MIN:
|
||||||
|
if (elementType.isIntOrIndex())
|
||||||
|
condition =
|
||||||
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, operand, result);
|
||||||
|
else
|
||||||
|
condition =
|
||||||
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, result);
|
||||||
|
result = rewriter.create<SelectOp>(loc, condition, operand, result);
|
||||||
|
break;
|
||||||
|
case vector::CombiningKind::MAX:
|
||||||
|
if (elementType.isIntOrIndex())
|
||||||
|
condition =
|
||||||
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, operand, result);
|
||||||
|
else
|
||||||
|
condition =
|
||||||
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, operand, result);
|
||||||
|
result = rewriter.create<SelectOp>(loc, condition, operand, result);
|
||||||
|
break;
|
||||||
|
case vector::CombiningKind::AND:
|
||||||
|
result = rewriter.create<AndOp>(loc, operand, result);
|
||||||
|
break;
|
||||||
|
case vector::CombiningKind::OR:
|
||||||
|
result = rewriter.create<OrOp>(loc, operand, result);
|
||||||
|
break;
|
||||||
|
case vector::CombiningKind::XOR:
|
||||||
|
result = rewriter.create<XOrOp>(loc, operand, result);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates shape cast for the inputs n_d -> 2d
|
rewriter.replaceOp(multiReductionOp, result);
|
||||||
auto castedType = VectorType::get(
|
|
||||||
{innerDims, outterDims},
|
|
||||||
multiReductionOp.getSourceVectorType().getElementType());
|
|
||||||
auto castedOp = rewriter.create<vector::ShapeCastOp>(
|
|
||||||
loc, castedType, multiReductionOp.source());
|
|
||||||
|
|
||||||
// Creates the canonical form of 2d vector.multi_reduction with inner most
|
|
||||||
// dim as reduction.
|
|
||||||
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
|
|
||||||
loc, castedOp.result(), ArrayRef<bool>{false, true},
|
|
||||||
multiReductionOp.kind());
|
|
||||||
|
|
||||||
// Creates shape cast for the output 2d -> nd
|
|
||||||
auto outputCastedType = VectorType::get(
|
|
||||||
innerDimsShape,
|
|
||||||
multiReductionOp.getSourceVectorType().getElementType());
|
|
||||||
Value castedOutputOp = rewriter.create<vector::ShapeCastOp>(
|
|
||||||
loc, outputCastedType, newOp.dest());
|
|
||||||
|
|
||||||
rewriter.replaceOp(multiReductionOp, castedOutputOp);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -3747,9 +3864,13 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
|
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns, bool useInnerDimsForReduction) {
|
||||||
patterns.add<InnerDimReductionConversion, ReduceMultiDimReductionRank,
|
patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
|
||||||
TwoDimMultiReductionToReduction>(patterns.getContext());
|
patterns.getContext(), useInnerDimsForReduction);
|
||||||
|
if (useInnerDimsForReduction)
|
||||||
|
patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
|
||||||
|
else
|
||||||
|
patterns.add<UnrollOuterMultiReduction>(patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::vector::populateVectorUnrollPatterns(
|
void mlir::vector::populateVectorUnrollPatterns(
|
||||||
|
|
|
@ -0,0 +1,161 @@
|
||||||
|
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s
|
||||||
|
|
||||||
|
func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||||
|
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||||
|
return %0 : vector<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @vector_multi_reduction
|
||||||
|
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
|
||||||
|
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
|
||||||
|
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[RV01:.+]] = mulf %[[V1]], %[[V0]] : vector<2xf32>
|
||||||
|
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[RV012:.+]] = mulf %[[V2]], %[[RV01]] : vector<2xf32>
|
||||||
|
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[RESULT_VEC:.+]] = mulf %[[V3]], %[[RV012]] : vector<2xf32>
|
||||||
|
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
|
||||||
|
|
||||||
|
func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||||
|
%0 = vector.multi_reduction #vector.kind<min>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||||
|
return %0 : vector<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @vector_multi_reduction_min
|
||||||
|
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
|
||||||
|
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
|
||||||
|
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[C0:.+]] = cmpf olt, %[[V1]], %[[V0]] : vector<2xf32>
|
||||||
|
// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32>
|
||||||
|
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[C1:.+]] = cmpf olt, %[[V2]], %[[RV01]] : vector<2xf32>
|
||||||
|
// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32>
|
||||||
|
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[C2:.+]] = cmpf olt, %[[V3]], %[[RV012]] : vector<2xf32>
|
||||||
|
// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32>
|
||||||
|
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
|
||||||
|
|
||||||
|
func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||||
|
%0 = vector.multi_reduction #vector.kind<max>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||||
|
return %0 : vector<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @vector_multi_reduction_max
|
||||||
|
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
|
||||||
|
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
|
||||||
|
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[C0:.+]] = cmpf oge, %[[V1]], %[[V0]] : vector<2xf32>
|
||||||
|
// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32>
|
||||||
|
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[C1:.+]] = cmpf oge, %[[V2]], %[[RV01]] : vector<2xf32>
|
||||||
|
// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32>
|
||||||
|
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
|
||||||
|
// CHECK: %[[C2:.+]] = cmpf oge, %[[V3]], %[[RV012]] : vector<2xf32>
|
||||||
|
// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32>
|
||||||
|
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
|
||||||
|
|
||||||
|
func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
|
||||||
|
%0 = vector.multi_reduction #vector.kind<and>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
|
||||||
|
return %0 : vector<2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @vector_multi_reduction_and
|
||||||
|
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>
|
||||||
|
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
|
||||||
|
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[RV01:.+]] = and %[[V1]], %[[V0]] : vector<2xi32>
|
||||||
|
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[RV012:.+]] = and %[[V2]], %[[RV01]] : vector<2xi32>
|
||||||
|
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[RESULT_VEC:.+]] = and %[[V3]], %[[RV012]] : vector<2xi32>
|
||||||
|
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
|
||||||
|
|
||||||
|
func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> {
|
||||||
|
%0 = vector.multi_reduction #vector.kind<or>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
|
||||||
|
return %0 : vector<2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @vector_multi_reduction_or
|
||||||
|
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>
|
||||||
|
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
|
||||||
|
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[RV01:.+]] = or %[[V1]], %[[V0]] : vector<2xi32>
|
||||||
|
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[RV012:.+]] = or %[[V2]], %[[RV01]] : vector<2xi32>
|
||||||
|
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[RESULT_VEC:.+]] = or %[[V3]], %[[RV012]] : vector<2xi32>
|
||||||
|
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
|
||||||
|
|
||||||
|
func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> {
|
||||||
|
%0 = vector.multi_reduction #vector.kind<xor>, %arg0 [1] : vector<2x4xi32> to vector<2xi32>
|
||||||
|
return %0 : vector<2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @vector_multi_reduction_xor
|
||||||
|
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>
|
||||||
|
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
|
||||||
|
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[RV01:.+]] = xor %[[V1]], %[[V0]] : vector<2xi32>
|
||||||
|
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[RV012:.+]] = xor %[[V2]], %[[RV01]] : vector<2xi32>
|
||||||
|
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32>
|
||||||
|
// CHECK: %[[RESULT_VEC:.+]] = xor %[[V3]], %[[RV012]] : vector<2xi32>
|
||||||
|
// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
|
||||||
|
|
||||||
|
|
||||||
|
func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
|
||||||
|
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
|
||||||
|
return %0 : vector<2x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @vector_reduction_outer
|
||||||
|
// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>
|
||||||
|
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32>
|
||||||
|
// CHECK: %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32>
|
||||||
|
// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R0:.+]] = addi %[[V1]], %[[V0]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED]][2] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R1:.+]] = addi %[[V2]], %[[R0]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R2:.+]] = addi %[[V3]], %[[R1]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED]][4] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R3:.+]] = addi %[[V4]], %[[R2]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED]][5] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R4:.+]] = addi %[[V5]], %[[R3]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V6:.+]] = vector.extract %[[RESHAPED]][6] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R5:.+]] = addi %[[V6]], %[[R4]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V7:.+]] = vector.extract %[[RESHAPED]][7] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R6:.+]] = addi %[[V7]], %[[R5]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V8:.+]] = vector.extract %[[RESHAPED]][8] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R7:.+]] = addi %[[V8]], %[[R6]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V9:.+]] = vector.extract %[[RESHAPED]][9] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R8:.+]] = addi %[[V9]], %[[R7]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V10:.+]] = vector.extract %[[RESHAPED]][10] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R9:.+]] = addi %[[V10]], %[[R8]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V11:.+]] = vector.extract %[[RESHAPED]][11] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R10:.+]] = addi %[[V11]], %[[R9]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V12:.+]] = vector.extract %[[RESHAPED]][12] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R11:.+]] = addi %[[V12]], %[[R10]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V13:.+]] = vector.extract %[[RESHAPED]][13] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R12:.+]] = addi %[[V13]], %[[R11]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V14:.+]] = vector.extract %[[RESHAPED]][14] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R13:.+]] = addi %[[V14]], %[[R12]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V15:.+]] = vector.extract %[[RESHAPED]][15] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R14:.+]] = addi %[[V15]], %[[R13]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V16:.+]] = vector.extract %[[RESHAPED]][16] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R15:.+]] = addi %[[V16]], %[[R14]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V17:.+]] = vector.extract %[[RESHAPED]][17] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R16:.+]] = addi %[[V17]], %[[R15]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V18:.+]] = vector.extract %[[RESHAPED]][18] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R17:.+]] = addi %[[V18]], %[[R16]] : vector<6xi32>
|
||||||
|
// CHECK: %[[V19:.+]] = vector.extract %[[RESHAPED]][19] : vector<20x6xi32>
|
||||||
|
// CHECK: %[[R18:.+]] = addi %[[V19]], %[[R17]] : vector<6xi32>
|
||||||
|
// CHECK: %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32>
|
||||||
|
// CHECK: return %[[RESULT_VEC]] : vector<2x3xi32>
|
|
@ -444,6 +444,9 @@ struct TestVectorTransferLoweringPatterns
|
||||||
struct TestVectorMultiReductionLoweringPatterns
|
struct TestVectorMultiReductionLoweringPatterns
|
||||||
: public PassWrapper<TestVectorMultiReductionLoweringPatterns,
|
: public PassWrapper<TestVectorMultiReductionLoweringPatterns,
|
||||||
FunctionPass> {
|
FunctionPass> {
|
||||||
|
TestVectorMultiReductionLoweringPatterns() = default;
|
||||||
|
TestVectorMultiReductionLoweringPatterns(
|
||||||
|
const TestVectorMultiReductionLoweringPatterns &pass) {}
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<memref::MemRefDialect>();
|
registry.insert<memref::MemRefDialect>();
|
||||||
}
|
}
|
||||||
|
@ -454,9 +457,13 @@ struct TestVectorMultiReductionLoweringPatterns
|
||||||
return "Test conversion patterns to lower vector.multi_reduction to other "
|
return "Test conversion patterns to lower vector.multi_reduction to other "
|
||||||
"vector ops";
|
"vector ops";
|
||||||
}
|
}
|
||||||
|
Option<bool> useOuterReductions{
|
||||||
|
*this, "use-outer-reductions",
|
||||||
|
llvm::cl::desc("Move reductions to outer most dimensions"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
RewritePatternSet patterns(&getContext());
|
RewritePatternSet patterns(&getContext());
|
||||||
populateVectorMultiReductionLoweringPatterns(patterns);
|
populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions);
|
||||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue