forked from OSchip/llvm-project
[mlir][Vector] Add a canonicalization pattern for vector.contract + add
Differential Revision: https://reviews.llvm.org/D96701
This commit is contained in:
parent
5f58374bbe
commit
02d053ed2d
|
@ -246,6 +246,8 @@ def Vector_ContractionOp :
|
|||
return CombiningKind::ADD;
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Vector_ReductionOp :
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
@ -658,6 +659,66 @@ Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
|
|||
return shape;
|
||||
}
|
||||
|
||||
/// Return a fused vector::ContractionOp which replaces a pattern such as:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %c0 = vector.constant 0: ...
|
||||
/// %c = vector.contract %a, %b, %c0: ...
|
||||
/// %e = add %c, %d: ...
|
||||
/// ```
|
||||
///
|
||||
/// by:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %e = vector.contract %a, %b, %d: ...
|
||||
/// ```
|
||||
///
|
||||
/// Return null if the canonicalization does not apply.
|
||||
// TODO: This should be a folding of Add into Contract in core but while they
|
||||
// live in different dialects, it is not possible without unnatural
|
||||
// dependencies.
|
||||
template <typename AddOpType>
|
||||
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
|
||||
using OpRewritePattern<AddOpType>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(AddOpType addOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto canonicalize = [&](Value maybeContraction,
|
||||
Value otherOperand) -> vector::ContractionOp {
|
||||
vector::ContractionOp contractionOp =
|
||||
dyn_cast_or_null<vector::ContractionOp>(
|
||||
maybeContraction.getDefiningOp());
|
||||
if (!contractionOp)
|
||||
return vector::ContractionOp();
|
||||
if (auto maybeZero = dyn_cast_or_null<ConstantOp>(
|
||||
contractionOp.acc().getDefiningOp())) {
|
||||
if (maybeZero.value() ==
|
||||
rewriter.getZeroAttr(contractionOp.acc().getType())) {
|
||||
BlockAndValueMapping bvm;
|
||||
bvm.map(contractionOp.acc(), otherOperand);
|
||||
auto newContraction =
|
||||
cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
|
||||
rewriter.replaceOp(addOp, newContraction.getResult());
|
||||
return newContraction;
|
||||
}
|
||||
}
|
||||
return vector::ContractionOp();
|
||||
};
|
||||
|
||||
Value a = addOp->getOperand(0), b = addOp->getOperand(1);
|
||||
vector::ContractionOp contract = canonicalize(a, b);
|
||||
contract = contract ? contract : canonicalize(b, a);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Populate `results` with the canonicalization patterns associated with
/// vector.contract: the add-folding pattern, instantiated once for integer
/// adds (AddIOp) and once for floating-point adds (AddFOp).
void ContractionOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<CanonicalizeContractAdd<AddIOp>>(context);
  results.insert<CanonicalizeContractAdd<AddFOp>>(context);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -710,3 +710,49 @@ func @dead_load(%base: memref<?xf32>, %indices: vector<16xi32>,
|
|||
memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#contraction_accesses0 = [
|
||||
affine_map<(i, j, k) -> (i, k)>,
|
||||
affine_map<(i, j, k) -> (k, j)>,
|
||||
affine_map<(i, j, k) -> (i, j)>
|
||||
]
|
||||
#contraction_trait0 = {
|
||||
indexing_maps = #contraction_accesses0,
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @contractions
|
||||
// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: vector<2x3xf32>
|
||||
// CHECK-SAME: %[[B:[0-9a-zA-Z]+]]: vector<3x4xf32>
|
||||
// CHECK-SAME: %[[C:[0-9a-zA-Z]+]]: vector<2x4xf32>
|
||||
// CHECK-SAME: %[[A_I8:[0-9a-zA-Z]+]]: vector<2x3xi8>
|
||||
// CHECK-SAME: %[[B_I8:[0-9a-zA-Z]+]]: vector<3x4xi8>
|
||||
// CHECK-SAME: %[[C_I8:[0-9a-zA-Z]+]]: vector<2x4xi8>
|
||||
// Test input for folding an add into a vector.contract whose accumulator is a
// zero constant, exercised once on f32 vectors (addf) and once on i8 vectors
// (addi). The FileCheck directives expect both the zero constants and the adds
// to be gone, with %c / %c_i8 used directly as the contraction accumulators.
func @contractions(%a: vector<2x3xf32>, %b: vector<3x4xf32>, %c: vector<2x4xf32>,
|
||||
%a_i8: vector<2x3xi8>, %b_i8: vector<3x4xi8>, %c_i8: vector<2x4xi8>)
|
||||
-> (vector<2x4xf32>, vector<2x4xi8>)
|
||||
{
|
||||
// f32 case: %vf_0 is the all-zero accumulator fed to the contraction.
// CHECK-NOT: constant
|
||||
%vf_0 = constant dense <0.0>: vector<2x4xf32>
|
||||
// CHECK-NOT: addf
|
||||
// CHECK: %[[D:.*]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]]
|
||||
%0 = vector.contract #contraction_trait0 %a, %b, %vf_0:
|
||||
vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>
|
||||
// CHECK-NOT: addf
|
||||
// The contraction result is the left-hand operand of the add.
%1 = addf %0, %c: vector<2x4xf32>
|
||||
|
||||
// i8 case: same structure as above, using an integer zero and addi.
// CHECK-NOT: constant
|
||||
%vi8_0 = constant dense <0>: vector<2x4xi8>
|
||||
// CHECK-NOT: addi
|
||||
// CHECK: %[[D_I8:.*]] = vector.contract {{.*}} %[[A_I8]], %[[B_I8]], %[[C_I8]]
|
||||
%i8_0 = vector.contract #contraction_trait0 %a_i8, %b_i8, %vi8_0:
|
||||
vector<2x3xi8>, vector<3x4xi8> into vector<2x4xi8>
|
||||
// CHECK-NOT: addi
|
||||
%i8_1 = addi %i8_0, %c_i8: vector<2x4xi8>
|
||||
|
||||
// Both returned values are expected to be the fused contraction results.
// CHECK: return %[[D]], %[[D_I8]]
|
||||
return %1, %i8_1: vector<2x4xf32>, vector<2x4xi8>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue