forked from OSchip/llvm-project
[mlir][vector] Add support for unrolling vector.fma
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D96781
This commit is contained in:
parent
85f025e5b3
commit
d8c7f442ea
|
@ -583,8 +583,10 @@ def Vector_ExtractMapOp :
|
|||
}
|
||||
|
||||
def Vector_FMAOp :
|
||||
Op<Vector_Dialect, "fma", [NoSideEffect,
|
||||
AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
|
||||
Op<Vector_Dialect, "fma", [
|
||||
NoSideEffect, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
|
||||
]>,
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
|
||||
Results<(outs AnyVector:$result)> {
|
||||
let summary = "vector fused multiply-add";
|
||||
|
|
|
@ -1258,6 +1258,14 @@ AffineMap calculateImplicitMap(MapOp op) {
|
|||
|
||||
AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FmaOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
|
||||
return llvm::to_vector<4>(getVectorType().getShape());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BroadcastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2456,8 +2464,7 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
|
|||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
|
||||
auto s = getVectorType().getShape();
|
||||
return SmallVector<int64_t, 4>{s.begin(), s.end()};
|
||||
return llvm::to_vector<4>(getVectorType().getShape());
|
||||
}
|
||||
|
||||
void TransferReadOp::getEffects(
|
||||
|
|
|
@ -73,3 +73,10 @@ func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
|
|||
// CHECK: vector.contract {
|
||||
// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
|
||||
// CHECK: return
|
||||
|
||||
func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf32>) -> vector<4x4xf32> {
|
||||
%0 = vector.fma %a, %b, %c: vector<4x4xf32>
|
||||
return %0 : vector<4x4xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @vector_fma
|
||||
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
|
||||
|
|
|
@ -151,8 +151,9 @@ struct TestVectorUnrollingPatterns
|
|||
patterns.insert<UnrollVectorPattern>(
|
||||
ctx, UnrollVectorOptions()
|
||||
.setNativeShape(ArrayRef<int64_t>{2, 2})
|
||||
.setFilterConstraint(
|
||||
[](Operation *op) { return success(isa<AddFOp>(op)); }));
|
||||
.setFilterConstraint([](Operation *op) {
|
||||
return success(isa<AddFOp, vector::FMAOp>(op));
|
||||
}));
|
||||
|
||||
if (unrollBasedOnType) {
|
||||
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
|
||||
|
|
Loading…
Reference in New Issue