Lower linalg.indexed_generic with libcall to LLVM.

PiperOrigin-RevId: 283328994
This commit is contained in:
Alexander Belyaev 2019-12-02 06:30:19 -08:00 committed by A. Unique TensorFlower
parent d5e627f84b
commit 9630fcbc52
2 changed files with 81 additions and 3 deletions

View File

@ -340,6 +340,28 @@ public:
}
};
/// Collects the types of all operands of `op` into a vector. This is the
/// generic version, used for every LinalgOp without an explicit specialization.
template <typename LinalgOp>
static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
  SmallVector<Type, 4> operandTypes(op->getOperandTypes());
  return operandTypes;
}
/// Specialization for IndexedGenericOp: the resulting type list starts with
/// one `index` type per loop of the op, followed by the types of the op's
/// operands in their original order.
template <>
SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
  const auto loopCount = cast<IndexedGenericOp>(op).getNumLoops();
  auto operandTypes = op->getOperandTypes();

  SmallVector<Type, 4> types;
  types.reserve(loopCount + op->getNumOperands());
  // Leading entries: one index type for each loop.
  types.append(loopCount, IndexType::get(op->getContext()));
  // Trailing entries: the operand types themselves.
  types.append(operandTypes.begin(), operandTypes.end());
  return types;
}
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
template <typename LinalgOp>
@ -359,7 +381,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
return fnNameAttr;
}
SmallVector<Type, 4> inputTypes(op->getOperandTypes());
SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
assert(op->getNumResults() == 0 &&
"Library call for linalg operation can be generated only for ops that "
"have void return types");
@ -430,6 +452,40 @@ public:
}
};
/// Conversion pattern specialization for IndexedGenericOp.
/// Replaces the op with a call to its library function. The call receives one
/// index argument per loop of the op, followed by the op's own operands, and
/// produces no results.
template <>
class LinalgOpConversion<IndexedGenericOp>
    : public OpRewritePattern<IndexedGenericOp> {
public:
  using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(IndexedGenericOp op,
                                     PatternRewriter &rewriter) const override {
    // Bail out when no library function symbol is available for this op.
    auto calleeRef = getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
    if (!calleeRef)
      return this->matchFailure();

    // TODO(pifon, ntv): Use induction variables values instead of zeros, when
    // IndexedGenericOp is tiled.
    Value *zeroIndex = rewriter.create<mlir::ConstantOp>(
        op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));

    // Call arguments: the zero constant repeated once per loop, then the
    // original operands in order.
    const auto loopCount = op.getNumLoops();
    auto opOperands = op.getOperands();
    SmallVector<Value *, 4> callArgs;
    callArgs.reserve(loopCount + op.getNumOperands());
    callArgs.append(loopCount, zeroIndex);
    callArgs.append(opOperands.begin(), opOperands.end());

    rewriter.replaceOpWithNewOp<mlir::CallOp>(op, calleeRef.getValue(),
                                              ArrayRef<Type>{}, callArgs);
    return this->matchSuccess();
  }
};
/// A non-conversion rewrite pattern kicks in to convert CopyOp with
/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
/// This interplays together with TransposeOpConversion and

View File

@ -141,7 +141,7 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "some_external_function_name_for_vector_outerproduct_matmul"
library_call = "external_outerproduct_matmul"
}
!vector_type_A = type vector<4xf32>
@ -162,7 +162,7 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
return
}
// CHECK-LABEL: func @matmul_vec_impl(
// CHECK: llvm.call @some_external_function_name_for_vector_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
// LLVM-LOOPS-LABEL: func @matmul_vec_impl(
// LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
@ -172,3 +172,25 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
// LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
// LLVM-LOOPS-NEXT: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
// LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
// Trait for a linalg.indexed_generic op with two parallel iterators and one
// reduction iterator. It reuses the #matmul_accesses indexing maps and names
// "external_indexed_outerproduct_matmul" as the library function to call.
#indexed_matmul_trait = {
n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_indexed_outerproduct_matmul"
}
// Applies linalg.indexed_generic with #indexed_matmul_trait to %A, %B and %C.
// The region takes three index arguments (%i, %j, %k) ahead of the three
// element arguments and yields the vector.outerproduct of %a, %b and %c.
func @matmul_vec_indexed(%A: !matrix_type_A,
%B: !matrix_type_B,
%C: !matrix_type_C) {
linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
^bb0(%i: index, %j: index, %k: index,
%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
%d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
linalg.yield %d: !vector_type_C
} : !matrix_type_A, !matrix_type_B, !matrix_type_C
return
}
// CHECK-LABEL: func @matmul_vec_indexed(
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()