forked from OSchip/llvm-project
Apply a level of sugaring to the linalg.generic EDSC - NFC
Make the declarative C++ builder API simpler to use so we can start chaining these ops together. PiperOrigin-RevId: 285496266
This commit is contained in:
parent
7ac42fa26e
commit
200beb8446
|
@ -22,20 +22,82 @@
|
|||
#ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
|
||||
#define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
|
||||
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
||||
namespace mlir {
|
||||
class BlockArgument;
|
||||
namespace edsc {
|
||||
|
||||
enum class IterType { Parallel, Reduction };
|
||||
|
||||
inline StringRef toString(IterType t) {
|
||||
switch (t) {
|
||||
case IterType::Parallel:
|
||||
return getParallelIteratorTypeName();
|
||||
case IterType::Reduction:
|
||||
return getParallelIteratorTypeName();
|
||||
default:
|
||||
llvm_unreachable("Unsupport IterType");
|
||||
}
|
||||
}
|
||||
|
||||
/// A StructuredIndexed represents a captured value that can be indexed and
|
||||
/// passed to the `makeLinalgGenericOp`. It allows writing intuitive index
|
||||
/// expressions such as:
|
||||
///
|
||||
/// ```
|
||||
/// StructuredIndexed A(vA), B(vB), C(vC);
|
||||
/// makeLinalgGenericOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
|
||||
/// ```
|
||||
struct StructuredIndexed {
|
||||
StructuredIndexed(Value *v) : value(v) {}
|
||||
StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
|
||||
return StructuredIndexed(value, indexings);
|
||||
}
|
||||
|
||||
operator Value *() const /* implicit */ { return value; }
|
||||
ArrayRef<AffineExpr> getExprs() { return exprs; }
|
||||
|
||||
private:
|
||||
StructuredIndexed(Value *v, ArrayRef<AffineExpr> indexings)
|
||||
: value(v), exprs(indexings.begin(), indexings.end()) {
|
||||
assert(v->getType().isa<MemRefType>() && "MemRefType expected");
|
||||
}
|
||||
StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
|
||||
: StructuredIndexed(v.getValue(), indexings) {}
|
||||
|
||||
Value *value;
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
};
|
||||
|
||||
inline void defaultRegionBuilder(ArrayRef<BlockArgument *> args) {}
|
||||
|
||||
/// EDSC entry point to build linalg.generic operations programmatically.
|
||||
Operation *makeLinalgGenericOp(
|
||||
ArrayRef<AffineExpr> indices, ArrayRef<ArrayRef<AffineExpr>> mapExpressions,
|
||||
ArrayRef<Value *> inputViews, ArrayRef<Value *> outputViews,
|
||||
ArrayRef<StringRef> iteratorTypes,
|
||||
decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder);
|
||||
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
|
||||
ArrayRef<StructuredIndexed> outputs,
|
||||
decltype(defaultRegionBuilder) regionBuilder = defaultRegionBuilder,
|
||||
ArrayRef<Value *> otherValues = {},
|
||||
ArrayRef<Attribute> otherAttributes = {});
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EDSC builders for linalg generic operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// TODO(ntv): In the future we should tie these implementations to something in
|
||||
/// Tablegen that generates the proper interfaces and the proper sugared named
|
||||
/// ops.
|
||||
|
||||
/// Build a linalg.generic that represents C = A * B in the current
|
||||
/// ScopedContext.
|
||||
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
|
||||
|
||||
template <typename Container> Operation *linalg_matmul(Container values) {
|
||||
assert(values.size() == 3 && "Expected exactly 3 values");
|
||||
return linalg_matmul(values[0], values[1], values[2]);
|
||||
}
|
||||
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
|
|
@ -15,50 +15,84 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
|
||||
static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
|
||||
unsigned &pos) {
|
||||
for (auto sidx : structuredIndices) {
|
||||
for (auto expr : sidx.getExprs()) {
|
||||
expr.walk([&pos](AffineExpr e) {
|
||||
if (auto d = e.dyn_cast<AffineDimExpr>())
|
||||
pos = std::max(pos, d.getPosition());
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Operation *mlir::edsc::makeLinalgGenericOp(
|
||||
ArrayRef<AffineExpr> indices, ArrayRef<ArrayRef<AffineExpr>> mapExpressions,
|
||||
ArrayRef<Value *> inputViews, ArrayRef<Value *> outputViews,
|
||||
ArrayRef<StringRef> iteratorTypes,
|
||||
decltype(defaultRegionBuilder) regionBuilder) {
|
||||
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
|
||||
ArrayRef<StructuredIndexed> outputs,
|
||||
decltype(defaultRegionBuilder) regionBuilder, ArrayRef<Value *> otherValues,
|
||||
ArrayRef<Attribute> otherAttributes) {
|
||||
auto &builder = edsc::ScopedContext::getBuilder();
|
||||
auto *ctx = builder.getContext();
|
||||
unsigned nInputs = inputs.size();
|
||||
unsigned nOutputs = outputs.size();
|
||||
unsigned rank = 0;
|
||||
getMaxDimIndex(inputs, rank);
|
||||
getMaxDimIndex(outputs, rank);
|
||||
|
||||
SmallVector<AffineMap, 4> maps;
|
||||
maps.reserve(mapExpressions.size());
|
||||
for (auto exprs : mapExpressions)
|
||||
maps.push_back(AffineMap::get(indices.size(), 0, exprs));
|
||||
maps.reserve(nInputs + nOutputs);
|
||||
for (auto in : inputs)
|
||||
maps.push_back(
|
||||
AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, in.getExprs()));
|
||||
for (auto out : outputs)
|
||||
maps.push_back(
|
||||
AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, out.getExprs()));
|
||||
|
||||
SmallVector<Value *, 4> views;
|
||||
views.reserve(inputViews.size() + outputViews.size());
|
||||
views.append(inputViews.begin(), inputViews.end());
|
||||
views.append(outputViews.begin(), outputViews.end());
|
||||
unsigned nViews = nInputs + nOutputs;
|
||||
SmallVector<Value *, 4> values;
|
||||
values.reserve(nViews);
|
||||
values.append(inputs.begin(), inputs.end());
|
||||
values.append(outputs.begin(), outputs.end());
|
||||
|
||||
auto iteratorStrTypes = functional::map(toString, iteratorTypes);
|
||||
// clang-format off
|
||||
auto *op =
|
||||
edsc::ScopedContext::getBuilder()
|
||||
.create<linalg::GenericOp>(
|
||||
edsc::ScopedContext::getLocation(), views,
|
||||
IntegerAttr::get(IntegerType::get(64, ctx), inputViews.size()),
|
||||
IntegerAttr::get(IntegerType::get(64, ctx), outputViews.size()),
|
||||
edsc::ScopedContext::getLocation(),
|
||||
values,
|
||||
IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
|
||||
IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
|
||||
builder.getAffineMapArrayAttr(maps),
|
||||
builder.getStrArrayAttr(iteratorTypes), StringAttr() /*doc*/,
|
||||
FlatSymbolRefAttr() /*fun*/, StringAttr() /*library_call*/
|
||||
builder.getStrArrayAttr(iteratorStrTypes),
|
||||
StringAttr() /*doc*/,
|
||||
FlatSymbolRefAttr() /*fun*/,
|
||||
StringAttr() /*library_call*/
|
||||
/* TODO: other attributes in op */
|
||||
)
|
||||
.getOperation();
|
||||
// clang-format on
|
||||
|
||||
using namespace edsc;
|
||||
SmallVector<Type, 4> blockTypes;
|
||||
blockTypes.reserve(views.size());
|
||||
for (auto *v : views)
|
||||
blockTypes.push_back(getElementTypeOrSelf(v));
|
||||
blockTypes.reserve(values.size());
|
||||
for (auto it : llvm::enumerate(values))
|
||||
blockTypes.push_back((it.index() < nViews)
|
||||
? getElementTypeOrSelf(it.value())
|
||||
: it.value()->getType());
|
||||
|
||||
assert(op->getRegions().front().empty());
|
||||
op->getRegions().front().push_front(new Block);
|
||||
|
@ -70,3 +104,24 @@ Operation *mlir::edsc::makeLinalgGenericOp(
|
|||
[&] { regionBuilder(b.getBlock()->getArguments()); });
|
||||
return op;
|
||||
}
|
||||
|
||||
using linalg_yield = OperationBuilder<linalg::YieldOp>;
|
||||
|
||||
Operation *mlir::edsc::linalg_matmul(ValueHandle vA, ValueHandle vB,
|
||||
ValueHandle vC) {
|
||||
// clang-format off
|
||||
AffineExpr m, n, k;
|
||||
bindDims(ScopedContext::getContext(), m, n, k);
|
||||
StructuredIndexed A(vA), B(vB), C(vC);
|
||||
return makeLinalgGenericOp(
|
||||
{IterType::Parallel, IterType::Parallel, IterType::Reduction},
|
||||
{A({m, n}), B({k, n})},
|
||||
{C({m, n})},
|
||||
[](ArrayRef<BlockArgument *> args) {
|
||||
using edsc::op::operator*;
|
||||
using edsc::op::operator+;
|
||||
ValueHandle a(args[0]), b(args[1]), c(args[2]);
|
||||
linalg_yield((c + a * b).getValue());
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -821,32 +821,15 @@ TEST_FUNC(affine_if_op) {
|
|||
// clang-format on
|
||||
TEST_FUNC(linalg_matmul) {
|
||||
using namespace edsc;
|
||||
using namespace edsc::intrinsics;
|
||||
using namespace edsc::op;
|
||||
using linalg_yield = OperationBuilder<linalg::YieldOp>;
|
||||
|
||||
auto f32Type = FloatType::getF32(&globalContext());
|
||||
auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
|
||||
auto f =
|
||||
makeFunction("linalg_matmul", {}, {memrefType, memrefType, memrefType});
|
||||
|
||||
// clang-format off
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
Value *A(f.getArgument(0)), *B(f.getArgument(1)), *C(f.getArgument(2));
|
||||
AffineExpr m, n, k;
|
||||
bindDims(f.getContext(), m, n, k);
|
||||
makeLinalgGenericOp(
|
||||
{m, n, k},
|
||||
{{m, n}, {k, n}, {m, n}},
|
||||
{A, B},
|
||||
{C},
|
||||
{"parallel", "parallel", "reduction"},
|
||||
[](ArrayRef<BlockArgument *> args) {
|
||||
ValueHandle a(args[0]), b(args[1]), c(args[2]);
|
||||
linalg_yield((c + a * b).getValue());
|
||||
});
|
||||
// clang-format on
|
||||
linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())));
|
||||
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
|
|
Loading…
Reference in New Issue