[mlir][Vector] Add a canonicalization pattern for vector.contract + add

Differential Revision: https://reviews.llvm.org/D96701
This commit is contained in:
Nicolas Vasilache 2021-02-15 12:11:29 +00:00
parent 5f58374bbe
commit 02d053ed2d
3 changed files with 109 additions and 0 deletions

View File

@ -246,6 +246,8 @@ def Vector_ContractionOp :
return CombiningKind::ADD;
}
}];
let hasCanonicalizer = 1;
}
def Vector_ReductionOp :

View File

@ -17,6 +17,7 @@
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
@ -658,6 +659,66 @@ Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
return shape;
}
/// Return a fused vector::ContractionOp which represents a patterns such as:
///
/// ```mlir
/// %c0 = vector.constant 0: ...
/// %c = vector.contract %a, %b, %c0: ...
/// %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
/// %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
using OpRewritePattern<AddOpType>::OpRewritePattern;
LogicalResult matchAndRewrite(AddOpType addOp,
PatternRewriter &rewriter) const override {
auto canonicalize = [&](Value maybeContraction,
Value otherOperand) -> vector::ContractionOp {
vector::ContractionOp contractionOp =
dyn_cast_or_null<vector::ContractionOp>(
maybeContraction.getDefiningOp());
if (!contractionOp)
return vector::ContractionOp();
if (auto maybeZero = dyn_cast_or_null<ConstantOp>(
contractionOp.acc().getDefiningOp())) {
if (maybeZero.value() ==
rewriter.getZeroAttr(contractionOp.acc().getType())) {
BlockAndValueMapping bvm;
bvm.map(contractionOp.acc(), otherOperand);
auto newContraction =
cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
rewriter.replaceOp(addOp, newContraction.getResult());
return newContraction;
}
}
return vector::ContractionOp();
};
Value a = addOp->getOperand(0), b = addOp->getOperand(1);
vector::ContractionOp contract = canonicalize(a, b);
contract = contract ? contract : canonicalize(b, a);
return success();
}
};
void ContractionOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results
.insert<CanonicalizeContractAdd<AddIOp>, CanonicalizeContractAdd<AddFOp>>(
context);
}
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

View File

@ -710,3 +710,49 @@ func @dead_load(%base: memref<?xf32>, %indices: vector<16xi32>,
memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return
}
// -----
#contraction_accesses0 = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait0 = {
indexing_maps = #contraction_accesses0,
iterator_types = ["parallel", "parallel", "reduction"]
}
// CHECK-LABEL: func @contractions
// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: vector<2x3xf32>
// CHECK-SAME: %[[B:[0-9a-zA-Z]+]]: vector<3x4xf32>
// CHECK-SAME: %[[C:[0-9a-zA-Z]+]]: vector<2x4xf32>
// CHECK-SAME: %[[A_I8:[0-9a-zA-Z]+]]: vector<2x3xi8>
// CHECK-SAME: %[[B_I8:[0-9a-zA-Z]+]]: vector<3x4xi8>
// CHECK-SAME: %[[C_I8:[0-9a-zA-Z]+]]: vector<2x4xi8>
func @contractions(%a: vector<2x3xf32>, %b: vector<3x4xf32>, %c: vector<2x4xf32>,
%a_i8: vector<2x3xi8>, %b_i8: vector<3x4xi8>, %c_i8: vector<2x4xi8>)
-> (vector<2x4xf32>, vector<2x4xi8>)
{
// CHECK-NOT: constant
%vf_0 = constant dense <0.0>: vector<2x4xf32>
// CHECK-NOT: addf
// CHECK: %[[D:.*]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]]
%0 = vector.contract #contraction_trait0 %a, %b, %vf_0:
vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>
// CHECK-NOT: addf
%1 = addf %0, %c: vector<2x4xf32>
// CHECK-NOT: constant
%vi8_0 = constant dense <0>: vector<2x4xi8>
// CHECK-NOT: addi
// CHECK: %[[D_I8:.*]] = vector.contract {{.*}} %[[A_I8]], %[[B_I8]], %[[C_I8]]
%i8_0 = vector.contract #contraction_trait0 %a_i8, %b_i8, %vi8_0:
vector<2x3xi8>, vector<3x4xi8> into vector<2x4xi8>
// CHECK-NOT: addi
%i8_1 = addi %i8_0, %c_i8: vector<2x4xi8>
// CHECK: return %[[D]], %[[D_I8]]
return %1, %i8_1: vector<2x4xf32>, vector<2x4xi8>
}