forked from OSchip/llvm-project
Lower linalg.indexed_generic with libcall to LLVM.
PiperOrigin-RevId: 283328994
This commit is contained in:
parent
d5e627f84b
commit
9630fcbc52
|
@ -340,6 +340,28 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
template <typename LinalgOp>
|
||||
static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
|
||||
return SmallVector<Type, 4>{op->getOperandTypes()};
|
||||
}
|
||||
|
||||
template <>
|
||||
SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
|
||||
auto ctx = op->getContext();
|
||||
auto indexedGenericOp = cast<IndexedGenericOp>(op);
|
||||
auto numLoops = indexedGenericOp.getNumLoops();
|
||||
|
||||
SmallVector<Type, 4> result;
|
||||
result.reserve(numLoops + op->getNumOperands());
|
||||
for (unsigned i = 0; i < numLoops; ++i) {
|
||||
result.push_back(IndexType::get(ctx));
|
||||
}
|
||||
for (auto type : op->getOperandTypes()) {
|
||||
result.push_back(type);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
|
||||
// If the library function does not exist, insert a declaration.
|
||||
template <typename LinalgOp>
|
||||
|
@ -359,7 +381,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
|
|||
return fnNameAttr;
|
||||
}
|
||||
|
||||
SmallVector<Type, 4> inputTypes(op->getOperandTypes());
|
||||
SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
|
||||
assert(op->getNumResults() == 0 &&
|
||||
"Library call for linalg operation can be generated only for ops that "
|
||||
"have void return types");
|
||||
|
@ -430,6 +452,40 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern specialization for IndexedGenericOp.
|
||||
template <>
|
||||
class LinalgOpConversion<IndexedGenericOp>
|
||||
: public OpRewritePattern<IndexedGenericOp> {
|
||||
public:
|
||||
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(IndexedGenericOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto libraryCallName =
|
||||
getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
|
||||
if (!libraryCallName)
|
||||
return this->matchFailure();
|
||||
|
||||
// TODO(pifon, ntv): Use induction variables values instead of zeros, when
|
||||
// IndexedGenericOp is tiled.
|
||||
auto zero = rewriter.create<mlir::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
||||
auto indexedGenericOp = cast<IndexedGenericOp>(op);
|
||||
auto numLoops = indexedGenericOp.getNumLoops();
|
||||
SmallVector<Value *, 4> operands;
|
||||
operands.reserve(numLoops + op.getNumOperands());
|
||||
for (unsigned i = 0; i < numLoops; ++i) {
|
||||
operands.push_back(zero);
|
||||
}
|
||||
for (auto operand : op.getOperands()) {
|
||||
operands.push_back(operand);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
|
||||
ArrayRef<Type>{}, operands);
|
||||
return this->matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
/// A non-conversion rewrite pattern kicks in to convert CopyOp with
|
||||
/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
|
||||
/// This interplays together with TransposeOpConversion and
|
||||
|
|
|
@ -141,7 +141,7 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
|
|||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "some_external_function_name_for_vector_outerproduct_matmul"
|
||||
library_call = "external_outerproduct_matmul"
|
||||
}
|
||||
|
||||
!vector_type_A = type vector<4xf32>
|
||||
|
@ -162,7 +162,7 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
|
|||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matmul_vec_impl(
|
||||
// CHECK: llvm.call @some_external_function_name_for_vector_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
|
||||
// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
|
||||
|
||||
// LLVM-LOOPS-LABEL: func @matmul_vec_impl(
|
||||
// LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
|
||||
|
@ -172,3 +172,25 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
|
|||
// LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
|
||||
// LLVM-LOOPS-NEXT: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
|
||||
// LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
|
||||
|
||||
|
||||
#indexed_matmul_trait = {
|
||||
n_views = [2, 1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "external_indexed_outerproduct_matmul"
|
||||
}
|
||||
func @matmul_vec_indexed(%A: !matrix_type_A,
|
||||
%B: !matrix_type_B,
|
||||
%C: !matrix_type_C) {
|
||||
linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
|
||||
^bb0(%i: index, %j: index, %k: index,
|
||||
%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
|
||||
%d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
|
||||
linalg.yield %d: !vector_type_C
|
||||
} : !matrix_type_A, !matrix_type_B, !matrix_type_C
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matmul_vec_indexed(
|
||||
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
|
||||
|
|
Loading…
Reference in New Issue