Add helper classes to declarative builders to help write end-to-end custom ops.

This CL adds the same helper classes that exist in the AST form of EDSCs: they support a basic indexing notation, emit the proper load and store operations, and capture MemRefViews as function arguments.

This CL also adds a wrapper class, LoopNestBuilder, to allow writing generic, rank-agnostic loop nests over indices.
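
For illustration, a minimal sketch of the intended end-to-end usage, modeled on the test added below (the function `f` and its memref arguments are assumed to exist in scope):

```c++
// Capture three memref arguments as views, index into them, and emit a
// rank-2 loop nest computing C(i, j) = A(i, j) + B(i, j).
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), vC(f->getArgument(2));
IndexedValue A(vA), B(vB), C(vC);
IndexHandle i, j, lb0, lb1, ub0, ub1;
int64_t step0, step1;
std::tie(lb0, ub0, step0) = vA.range(0);
std::tie(lb1, ub1, step1) = vA.range(1);
LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})({
  C({i, j}) = A({i, j}) + B({i, j}),
});
```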

PiperOrigin-RevId: 237113755
Nicolas Vasilache 2019-03-06 13:54:41 -08:00 committed by jpienaar
parent 4fc9b51727
commit 7c0b9e8b62
7 changed files with 407 additions and 2 deletions


@@ -33,6 +33,7 @@ namespace edsc {
struct index_t {
explicit index_t(int64_t v) : v(v) {}
explicit operator int64_t() { return v; }
int64_t v;
};
@@ -147,6 +148,39 @@ public:
ValueHandle operator()(ArrayRef<ValueHandle> stmts);
};
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
/// explicitly writing all the loops in a nest. This simple functionality is
/// also useful for writing rank-agnostic custom ops.
///
/// Usage:
///
/// ```c++
/// LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})({
/// ...
/// });
/// ```
///
/// ```c++
/// LoopNestBuilder({&i}, {lb}, {ub}, {1})({
/// LoopNestBuilder({&j}, {lb}, {ub}, {1})({
/// LoopNestBuilder({&k}, {lb}, {ub}, {1})({
/// ...
/// }),
/// }),
/// });
/// ```
class LoopNestBuilder {
public:
LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
// TODO(ntv): when loops return escaping ssa-values, this should be adapted.
ValueHandle operator()(ArrayRef<ValueHandle> stmts);
private:
SmallVector<LoopBuilder, 4> loops;
};
// This class exists solely to handle the C++ vexing parse case when
// trying to enter a Block that has already been constructed.
class Append {};


@@ -0,0 +1,162 @@
//===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Provides helper classes and syntactic sugar for declarative builders.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EDSC_HELPERS_H_
#define MLIR_EDSC_HELPERS_H_
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
namespace mlir {
namespace edsc {
class IndexedValue;
/// An IndexHandle is a simple wrapper around a ValueHandle.
/// IndexHandles are ubiquitous enough to justify a new type to allow simple
/// declarations without boilerplate such as:
///
/// ```c++
/// IndexHandle i, j, k;
/// ```
struct IndexHandle : public ValueHandle {
explicit IndexHandle()
: ValueHandle(ScopedContext::getBuilder()->getIndexType()) {}
explicit IndexHandle(index_t v) : ValueHandle(v) {}
explicit IndexHandle(Value *v) : ValueHandle(v) {
assert(v->getType() == ScopedContext::getBuilder()->getIndexType() &&
"Expected index type");
}
explicit IndexHandle(ValueHandle v) : ValueHandle(v) {}
};
/// A MemRefView represents the information required to step through a
/// MemRef. It has placeholders for non-contiguous tensors that fit within the
/// Fortran subarray model.
/// At the moment it can only capture a MemRef with an identity layout map.
// TODO(ntv): Support MemRefs with layoutMaps.
class MemRefView {
public:
explicit MemRefView(Value *v);
MemRefView(const MemRefView &) = default;
MemRefView &operator=(const MemRefView &) = default;
unsigned rank() const { return lbs.size(); }
unsigned fastestVarying() const { return rank() - 1; }
std::tuple<IndexHandle, IndexHandle, int64_t> range(unsigned idx) {
return std::make_tuple(lbs[idx], ubs[idx], steps[idx]);
}
private:
friend IndexedValue;
ValueHandle base;
SmallVector<IndexHandle, 8> lbs;
SmallVector<IndexHandle, 8> ubs;
SmallVector<int64_t, 8> steps;
};
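
As a hedged illustration (assuming `f` is a function whose first argument is a memref with an identity layout), a view can be captured from an argument and its per-dimension ranges used as loop bounds:

```c++
MemRefView vA(f->getArgument(0));
IndexHandle lb, ub;
int64_t step;
// Query the range of the fastest-varying (innermost) dimension.
std::tie(lb, ub, step) = vA.range(vA.fastestVarying());
```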
ValueHandle operator+(ValueHandle v, IndexedValue i);
ValueHandle operator-(ValueHandle v, IndexedValue i);
ValueHandle operator*(ValueHandle v, IndexedValue i);
ValueHandle operator/(ValueHandle v, IndexedValue i);
/// This helper class is an abstraction over memref that exists purely for
/// sugaring purposes and allows writing compact expressions such as:
///
/// ```mlir
/// IndexedValue A(...), B(...), C(...);
/// For(ivs, zeros, shapeA, ones, {
/// C(ivs) = A(ivs) + B(ivs)
/// });
/// ```
///
/// Assigning to an IndexedValue emits an actual store operation, while
/// converting an IndexedValue to a ValueHandle emits an actual load operation.
struct IndexedValue {
explicit IndexedValue(MemRefView &v, llvm::ArrayRef<ValueHandle> indices = {})
: view(v), indices(indices.begin(), indices.end()) {}
IndexedValue(const IndexedValue &rhs) = default;
IndexedValue &operator=(const IndexedValue &rhs) = default;
/// Returns a new `IndexedValue`.
IndexedValue operator()(llvm::ArrayRef<ValueHandle> indices = {}) {
return IndexedValue(view, indices);
}
/// Emits a `store`.
// NOLINTNEXTLINE: unconventional-assign-operator
ValueHandle operator=(ValueHandle rhs) {
return intrinsics::STORE(rhs, getBase(), indices);
}
ValueHandle getBase() const { return view.base; }
/// Emits a `load` when converting to a ValueHandle.
explicit operator ValueHandle() {
return intrinsics::LOAD(getBase(), indices);
}
/// Operator overloadings.
ValueHandle operator+(ValueHandle e);
ValueHandle operator-(ValueHandle e);
ValueHandle operator*(ValueHandle e);
ValueHandle operator/(ValueHandle e);
ValueHandle operator+=(ValueHandle e);
ValueHandle operator-=(ValueHandle e);
ValueHandle operator*=(ValueHandle e);
ValueHandle operator/=(ValueHandle e);
ValueHandle operator+(IndexedValue e) {
return *this + static_cast<ValueHandle>(e);
}
ValueHandle operator-(IndexedValue e) {
return *this - static_cast<ValueHandle>(e);
}
ValueHandle operator*(IndexedValue e) {
return *this * static_cast<ValueHandle>(e);
}
ValueHandle operator/(IndexedValue e) {
return *this / static_cast<ValueHandle>(e);
}
ValueHandle operator+=(IndexedValue e) {
return this->operator+=(static_cast<ValueHandle>(e));
}
ValueHandle operator-=(IndexedValue e) {
return this->operator-=(static_cast<ValueHandle>(e));
}
ValueHandle operator*=(IndexedValue e) {
return this->operator*=(static_cast<ValueHandle>(e));
}
ValueHandle operator/=(IndexedValue e) {
return this->operator/=(static_cast<ValueHandle>(e));
}
private:
MemRefView &view;
llvm::SmallVector<ValueHandle, 8> indices;
};
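
A short sketch of the resulting sugar (assuming `A`, `B`, and `C` are IndexedValues over captured views and `i`, `j` are valid index handles; all names hypothetical):

```c++
// Reading A({i, j}) or B({i, j}) emits a load, '+' emits an addf (via the
// edsc::op overloads), and assigning to C({i, j}) emits a store.
C({i, j}) = A({i, j}) + B({i, j});
// Compound assignment also reloads the destination before storing back.
C({i, j}) += A({i, j});
```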
} // namespace edsc
} // namespace mlir
#endif // MLIR_EDSC_HELPERS_H_


@@ -95,11 +95,22 @@ ValueHandle COND_BR(ValueHandle cond, BlockHandle *trueBranch,
////////////////////////////////////////////////////////////////////////////////
// TODO(ntv): Intrinsics below this line should be TableGen'd.
////////////////////////////////////////////////////////////////////////////////
/// Builds an mlir::LoadOp from `base` and `indices`, each of which must have
/// captured an mlir::Value*.
/// Returns a ValueHandle to the produced mlir::Value*.
ValueHandle LOAD(ValueHandle base, llvm::ArrayRef<ValueHandle> indices);
/// Builds an mlir::ReturnOp with the proper `operands`, each of which must
/// have captured an mlir::Value*.
/// Returns an empty ValueHandle.
ValueHandle RETURN(llvm::ArrayRef<ValueHandle> operands);
/// Builds an mlir::StoreOp that stores `value` into `base` at `indices`, each
/// of which must have captured an mlir::Value*.
/// Returns an empty ValueHandle.
ValueHandle STORE(ValueHandle value, ValueHandle base,
llvm::ArrayRef<ValueHandle> indices);
} // namespace intrinsics
} // namespace edsc
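
The intrinsics can also be called directly; a minimal sketch, assuming `src` and `dst` capture memref values and `i`, `j` are valid index handles:

```c++
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
ValueHandle tmp = LOAD(src, {i, j});  // emits an mlir::LoadOp
STORE(tmp, dst, {i, j});              // emits an mlir::StoreOp
RETURN({});                           // emits an mlir::ReturnOp with no operands
```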


@@ -162,6 +162,32 @@ ValueHandle mlir::edsc::LoopBuilder::operator()(ArrayRef<ValueHandle> stmts) {
return ValueHandle::null();
}
mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs,
ArrayRef<int64_t> steps) {
assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
for (auto it : llvm::zip(ivs, lbs, ubs, steps)) {
loops.emplace_back(std::get<0>(it), std::get<1>(it), std::get<2>(it),
std::get<3>(it));
}
}
ValueHandle
mlir::edsc::LoopNestBuilder::operator()(ArrayRef<ValueHandle> stmts) {
// Iterate over the loops in the nest, calling operator() on each.
// The iteration order is from innermost to outermost because enter/exit needs
// to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
// occurs on calling operator()). The asymmetry is required for properly
// nesting imperfectly nested regions (see LoopBuilder::operator()).
for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) {
(*lit)({});
}
return ValueHandle::null();
}
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {
assert(bh && "Expected already captured BlockHandle");
enter(bh.getBlock());

mlir/lib/EDSC/Helpers.cpp (new file, 110 lines)

@@ -0,0 +1,110 @@
//===- Helpers.cpp - MLIR Declarative Helper Functionality ------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/StandardOps/Ops.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Builders.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Helpers.h"
using namespace mlir;
using namespace mlir::edsc;
static SmallVector<IndexHandle, 8> getMemRefSizes(Value *memRef) {
MemRefType memRefType = memRef->getType().cast<MemRefType>();
auto maps = memRefType.getAffineMaps();
assert((maps.empty() || (maps.size() == 1 && maps[0].isIdentity())) &&
"Layout maps not supported");
SmallVector<IndexHandle, 8> res;
res.reserve(memRefType.getShape().size());
const auto &shape = memRefType.getShape();
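// Dynamic dimensions (shape[idx] == -1) are materialized with a DimOp; static sizes become index constants.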
for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) {
if (shape[idx] == -1) {
res.push_back(IndexHandle(ValueHandle::create<DimOp>(memRef, idx)));
} else {
res.push_back(IndexHandle(static_cast<index_t>(shape[idx])));
}
}
return res;
}
mlir::edsc::MemRefView::MemRefView(Value *v) : base(v) {
assert(v->getType().isa<MemRefType>() && "MemRefType expected");
auto memrefSizeValues = getMemRefSizes(v);
for (auto &size : memrefSizeValues) {
lbs.push_back(IndexHandle(static_cast<index_t>(0)));
ubs.push_back(size);
steps.push_back(1);
}
}
/// Operator overloadings.
ValueHandle mlir::edsc::IndexedValue::operator+(ValueHandle e) {
using op::operator+;
return static_cast<ValueHandle>(*this) + e;
}
ValueHandle mlir::edsc::IndexedValue::operator-(ValueHandle e) {
using op::operator-;
return static_cast<ValueHandle>(*this) - e;
}
ValueHandle mlir::edsc::IndexedValue::operator*(ValueHandle e) {
using op::operator*;
return static_cast<ValueHandle>(*this) * e;
}
ValueHandle mlir::edsc::IndexedValue::operator/(ValueHandle e) {
using op::operator/;
return static_cast<ValueHandle>(*this) / e;
}
ValueHandle mlir::edsc::IndexedValue::operator+=(ValueHandle e) {
using op::operator+;
return intrinsics::STORE(*this + e, getBase(), indices);
}
ValueHandle mlir::edsc::IndexedValue::operator-=(ValueHandle e) {
using op::operator-;
return intrinsics::STORE(*this - e, getBase(), indices);
}
ValueHandle mlir::edsc::IndexedValue::operator*=(ValueHandle e) {
using op::operator*;
return intrinsics::STORE(*this * e, getBase(), indices);
}
ValueHandle mlir::edsc::IndexedValue::operator/=(ValueHandle e) {
using op::operator/;
return intrinsics::STORE(*this / e, getBase(), indices);
}
ValueHandle mlir::edsc::operator+(ValueHandle v, IndexedValue i) {
using op::operator+;
return v + static_cast<ValueHandle>(i);
}
ValueHandle mlir::edsc::operator-(ValueHandle v, IndexedValue i) {
using op::operator-;
return v - static_cast<ValueHandle>(i);
}
ValueHandle mlir::edsc::operator*(ValueHandle v, IndexedValue i) {
using op::operator*;
return v * static_cast<ValueHandle>(i);
}
ValueHandle mlir::edsc::operator/(ValueHandle v, IndexedValue i) {
using op::operator/;
return v / static_cast<ValueHandle>(i);
}


@@ -100,8 +100,21 @@ ValueHandle mlir::edsc::intrinsics::COND_BR(
////////////////////////////////////////////////////////////////////////////////
// TODO(ntv): Intrinsics below this line should be TableGen'd.
////////////////////////////////////////////////////////////////////////////////
ValueHandle
mlir::edsc::intrinsics::LOAD(ValueHandle base,
llvm::ArrayRef<ValueHandle> indices = {}) {
SmallVector<Value *, 4> ops(indices.begin(), indices.end());
return ValueHandle::create<LoadOp>(base.getValue(), ops);
}
ValueHandle mlir::edsc::intrinsics::RETURN(ArrayRef<ValueHandle> operands) {
SmallVector<Value *, 4> ops(operands.begin(), operands.end());
return ValueHandle::create<ReturnOp>(ops);
}
ValueHandle
mlir::edsc::intrinsics::STORE(ValueHandle value, ValueHandle base,
llvm::ArrayRef<ValueHandle> indices = {}) {
SmallVector<Value *, 4> ops(indices.begin(), indices.end());
return ValueHandle::create<StoreOp>(value.getValue(), base.getValue(), ops);
}


@@ -20,9 +20,8 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
@@ -317,6 +316,56 @@ TEST_FUNC(builder_cond_branch_eager) {
f->print(llvm::outs());
}
TEST_FUNC(builder_helpers) {
using namespace edsc;
using namespace edsc::intrinsics;
using namespace edsc::op;
auto f32Type = FloatType::getF32(&globalContext());
auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0);
auto f =
makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType});
ScopedContext scope(f.get());
// clang-format off
ValueHandle f7(
ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), vC(f->getArgument(2));
IndexedValue A(vA), B(vB), C(vC);
IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2;
int64_t step0, step1, step2;
std::tie(lb0, ub0, step0) = vA.range(0);
std::tie(lb1, ub1, step1) = vA.range(1);
std::tie(lb2, ub2, step2) = vA.range(2);
LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})({
LoopBuilder(&k1, lb2, ub2, step2)({
C({i, j, k1}) = f7 + A({i, j, k1}) + B({i, j, k1}),
}),
LoopBuilder(&k2, lb2, ub2, step2)({
C({i, j, k2}) += A({i, j, k2}) + B({i, j, k2}),
}),
});
// CHECK-LABEL: @builder_helpers
// CHECK: for %i0 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) {
// CHECK-NEXT: for %i1 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) {
// CHECK-NEXT: for %i2 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) {
// CHECK-NEXT: [[a:%.*]] = load %arg0[%i0, %i1, %i2] : memref<?x?x?xf32>
// CHECK-NEXT: [[b:%.*]] = addf {{.*}}, [[a]] : f32
// CHECK-NEXT: [[c:%.*]] = load %arg1[%i0, %i1, %i2] : memref<?x?x?xf32>
// CHECK-NEXT: [[d:%.*]] = addf [[b]], [[c]] : f32
// CHECK-NEXT: store [[d]], %arg2[%i0, %i1, %i2] : memref<?x?x?xf32>
// CHECK-NEXT: }
// CHECK-NEXT: for %i3 = (d0) -> (d0)(%c0_1) to (d0) -> (d0)(%2) {
// CHECK-NEXT: [[a:%.*]] = load %arg1[%i0, %i1, %i3] : memref<?x?x?xf32>
// CHECK-NEXT: [[b:%.*]] = load %arg0[%i0, %i1, %i3] : memref<?x?x?xf32>
// CHECK-NEXT: [[c:%.*]] = addf [[b]], [[a]] : f32
// CHECK-NEXT: [[d:%.*]] = load %arg2[%i0, %i1, %i3] : memref<?x?x?xf32>
// CHECK-NEXT: [[e:%.*]] = addf [[d]], [[c]] : f32
// CHECK-NEXT: store [[e]], %arg2[%i0, %i1, %i3] : memref<?x?x?xf32>
// clang-format on
f->print(llvm::outs());
}
int main() {
RUN_TESTS();
return 0;