From 200beb84461bd249589913a3d89898a9c6e588b9 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 13 Dec 2019 16:35:49 -0800 Subject: [PATCH] Apply a level of sugaring to the linalg.generic EDSC - NFC Make the declarative C++ builder API simpler to use so we can start chaining these ops together. PiperOrigin-RevId: 285496266 --- .../mlir/Dialect/Linalg/EDSC/Builders.h | 72 +++++++++++++- mlir/lib/Dialect/Linalg/EDSC/Builders.cpp | 95 +++++++++++++++---- mlir/test/EDSC/builder-api-test.cpp | 19 +--- 3 files changed, 143 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h index 3618ec14468e..00da1d68cf25 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -22,20 +22,82 @@ #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" namespace mlir { class BlockArgument; namespace edsc { +enum class IterType { Parallel, Reduction }; + +inline StringRef toString(IterType t) { + switch (t) { + case IterType::Parallel: + return getParallelIteratorTypeName(); + case IterType::Reduction: + return getParallelIteratorTypeName(); + default: + llvm_unreachable("Unsupport IterType"); + } +} + +/// A StructuredIndexed represents a captured value that can be indexed and +/// passed to the `makeLinalgGenericOp`. It allows writing intuitive index +/// expressions such as: +/// +/// ``` +/// StructuredIndexed A(vA), B(vB), C(vC); +/// makeLinalgGenericOp({A({m, n}), B({k, n})}, {C({m, n})}, ... ); +/// ``` +struct StructuredIndexed { + StructuredIndexed(Value *v) : value(v) {} + StructuredIndexed operator()(ArrayRef indexings) { + return StructuredIndexed(value, indexings); + } + + operator Value *() const /* implicit */ { return value; } + ArrayRef getExprs() { return exprs; } + +private: + StructuredIndexed(Value *v, ArrayRef indexings) + : value(v), exprs(indexings.begin(), indexings.end()) { + assert(v->getType().isa() && "MemRefType expected"); + } + StructuredIndexed(ValueHandle v, ArrayRef indexings) + : StructuredIndexed(v.getValue(), indexings) {} + + Value *value; + SmallVector exprs; +}; + inline void defaultRegionBuilder(ArrayRef args) {} -/// EDSC entry point to build linalg.generic operations programmatically. Operation *makeLinalgGenericOp( - ArrayRef indices, ArrayRef> mapExpressions, - ArrayRef inputViews, ArrayRef outputViews, - ArrayRef iteratorTypes, - decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder); + ArrayRef iteratorTypes, ArrayRef inputs, + ArrayRef outputs, + decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder, + ArrayRef otherValues = {}, + ArrayRef otherAttributes = {}); + +//===----------------------------------------------------------------------===// +// EDSC builders for linalg generic operations. +//===----------------------------------------------------------------------===// + +/// TODO(ntv): In the future we should tie these implementations to something in +/// Tablegen that generates the proper interfaces and the proper sugared named +/// ops. + +/// Build a linalg.generic that represents C = A * B in the current +/// ScopedContext. +Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC); + +template Operation *linalg_matmul(Container values) { + assert(values.size() == 3 && "Expected exactly 3 values"); + return linalg_matmul(values[0], values[1], values[2]); +} } // namespace edsc } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 606160b9b14c..3daeafe00ca3 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -15,50 +15,84 @@ // limitations under the License. // ============================================================================= -#include "mlir/EDSC/Builders.h" #include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/Support/Functional.h" using namespace mlir; using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; + +static void getMaxDimIndex(ArrayRef structuredIndices, + unsigned &pos) { + for (auto sidx : structuredIndices) { + for (auto expr : sidx.getExprs()) { + expr.walk([&pos](AffineExpr e) { + if (auto d = e.dyn_cast()) + pos = std::max(pos, d.getPosition()); + }); + } + } +} Operation *mlir::edsc::makeLinalgGenericOp( - ArrayRef indices, ArrayRef> mapExpressions, - ArrayRef inputViews, ArrayRef outputViews, - ArrayRef iteratorTypes, - decltype(defaultRegionBuilder) regionBuilder) { + ArrayRef iteratorTypes, ArrayRef inputs, + ArrayRef outputs, + decltype(defaultRegionBuilder) regionBuilder, ArrayRef otherValues, + ArrayRef otherAttributes) { auto &builder = edsc::ScopedContext::getBuilder(); auto *ctx = builder.getContext(); + unsigned nInputs = inputs.size(); + unsigned nOutputs = outputs.size(); + unsigned rank = 0; + getMaxDimIndex(inputs, rank); + getMaxDimIndex(outputs, rank); SmallVector maps; - maps.reserve(mapExpressions.size()); - for (auto exprs : mapExpressions) - maps.push_back(AffineMap::get(indices.size(), 0, exprs)); + maps.reserve(nInputs + nOutputs); + for (auto in : inputs) + maps.push_back( + AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs())); + for (auto out : outputs) + maps.push_back( + AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs())); - SmallVector views; - views.reserve(inputViews.size() + outputViews.size()); - views.append(inputViews.begin(), inputViews.end()); - views.append(outputViews.begin(), outputViews.end()); + unsigned nViews = nInputs + nOutputs; + SmallVector values; + values.reserve(nViews); + values.append(inputs.begin(), inputs.end()); + values.append(outputs.begin(), outputs.end()); + auto iteratorStrTypes = functional::map(toString, iteratorTypes); + // clang-format off auto *op = edsc::ScopedContext::getBuilder() .create( - edsc::ScopedContext::getLocation(), views, - IntegerAttr::get(IntegerType::get(64, ctx), inputViews.size()), - IntegerAttr::get(IntegerType::get(64, ctx), outputViews.size()), + edsc::ScopedContext::getLocation(), + values, + IntegerAttr::get(IntegerType::get(64, ctx), nInputs), + IntegerAttr::get(IntegerType::get(64, ctx), nOutputs), builder.getAffineMapArrayAttr(maps), - builder.getStrArrayAttr(iteratorTypes), StringAttr() /*doc*/, - FlatSymbolRefAttr() /*fun*/, StringAttr() /*library_call*/ + builder.getStrArrayAttr(iteratorStrTypes), + StringAttr() /*doc*/, + FlatSymbolRefAttr() /*fun*/, + StringAttr() /*library_call*/ + /* TODO: other attributes in op */ ) .getOperation(); + // clang-format on using namespace edsc; SmallVector blockTypes; - blockTypes.reserve(views.size()); - for (auto *v : views) - blockTypes.push_back(getElementTypeOrSelf(v)); + blockTypes.reserve(values.size()); + for (auto it : llvm::enumerate(values)) + blockTypes.push_back((it.index() < nViews) + ? getElementTypeOrSelf(it.value()) + : it.value()->getType()); assert(op->getRegions().front().empty()); op->getRegions().front().push_front(new Block); @@ -70,3 +104,24 @@ Operation *mlir::edsc::makeLinalgGenericOp( [&] { regionBuilder(b.getBlock()->getArguments()); }); return op; } + +using linalg_yield = OperationBuilder; + +Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB, + ValueHandle vC) { + // clang-format off + AffineExpr m, n, k; + bindDims(ScopedContext::getContext(), m, n, k); + StructuredIndexed A(vA), B(vB), C(vC); + return makeLinalgGenericOp( + {IterType::Parallel, IterType::Parallel, IterType::Reduction}, + {A({m, n}), B({k, n})}, + {C({m, n})}, + [](ArrayRef args) { + using edsc::op::operator*; + using edsc::op::operator+; + ValueHandle a(args[0]), b(args[1]), c(args[2]); + linalg_yield((c + a * b).getValue()); + }); + // clang-format on +} diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index dc17305f4df7..abd1eb0cac63 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -821,32 +821,15 @@ TEST_FUNC(affine_if_op) { // clang-format on TEST_FUNC(linalg_matmul) { using namespace edsc; - using namespace edsc::intrinsics; - using namespace edsc::op; - using linalg_yield = OperationBuilder; auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0); auto f = makeFunction("linalg_matmul", {}, {memrefType, memrefType, memrefType}); - // clang-format off OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - Value *A(f.getArgument(0)), *B(f.getArgument(1)), *C(f.getArgument(2)); - AffineExpr m, n, k; - bindDims(f.getContext(), m, n, k); - makeLinalgGenericOp( - {m, n, k}, - {{m, n}, {k, n}, {m, n}}, - {A, B}, - {C}, - {"parallel", "parallel", "reduction"}, - [](ArrayRef args) { - ValueHandle a(args[0]), b(args[1]), c(args[2]); - linalg_yield((c + a * b).getValue()); - }); - // clang-format on + linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments()))); f.print(llvm::outs()); f.erase();