forked from OSchip/llvm-project
Linalg portion of the tutorial - part 4
This CL adds declarative tiling support in the linalg dialect by providing: 1. loop tiling on linalg ops by simply calling into mlir::tile 2. view tiling on linalg ops by: a. computing the subview between for each tile dimension based on the loop tile size and the mapping of loops to operand ranges. b. declaring that the tiled form of a tensorcontraction is the same tensorcontraction on subviews, which essentially gives us a recursive form. Point 2.b is potentially subject to change in the future. -- PiperOrigin-RevId: 242058658
This commit is contained in:
parent
fde21c6faf
commit
92df395068
|
@ -70,8 +70,13 @@ public:
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
// Used in Linalg3 and later.
|
// Used in Linalg3 and later.
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
mlir::Value *getInputView(unsigned i);
|
mlir::Value *getInputView(unsigned viewIndex);
|
||||||
mlir::Value *getOutputView(unsigned i);
|
mlir::Value *getOutputView(unsigned viewIndex);
|
||||||
|
mlir::Value *getView(unsigned viewIndex) {
|
||||||
|
return viewIndex < getNumInputs()
|
||||||
|
? getInputView(viewIndex)
|
||||||
|
: getOutputView(viewIndex - getNumInputs());
|
||||||
|
}
|
||||||
|
|
||||||
/// Each op is responsible for declaring how it lowers itself to scalar form,
|
/// Each op is responsible for declaring how it lowers itself to scalar form,
|
||||||
/// given the enclosing parallel and reduction induction variables.
|
/// given the enclosing parallel and reduction induction variables.
|
||||||
|
@ -86,10 +91,9 @@ public:
|
||||||
/// ConcreteOp implementation, the resulting map must match those.
|
/// ConcreteOp implementation, the resulting map must match those.
|
||||||
/// In favorable cases, this can be calculated by an analysis but specifying
|
/// In favorable cases, this can be calculated by an analysis but specifying
|
||||||
/// it explicitly is not expensive and generalizes to cases where an analysis
|
/// it explicitly is not expensive and generalizes to cases where an analysis
|
||||||
/// is not available.
|
/// is not available. For details, see the description of
|
||||||
/// For details, see the description of loopsToOperandRangesMap in each
|
/// loopsToOperandRangeMaps in each ConcreteOp.
|
||||||
/// ConcreteOp
|
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
|
||||||
mlir::AffineMap loopsToOperandRangesMap();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
|
/// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
|
||||||
|
@ -135,7 +139,7 @@ public:
|
||||||
/// (d0) -> (d0, d0)(%k)
|
/// (d0) -> (d0, d0)(%k)
|
||||||
/// And the operands ranges are:
|
/// And the operands ranges are:
|
||||||
/// (%k, %k)
|
/// (%k, %k)
|
||||||
mlir::AffineMap loopsToOperandRangesMap();
|
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
|
||||||
|
|
||||||
/// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
|
/// Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
|
||||||
/// to:
|
/// to:
|
||||||
|
@ -195,7 +199,7 @@ public:
|
||||||
/// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
|
/// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
|
||||||
/// And the operands ranges are:
|
/// And the operands ranges are:
|
||||||
/// (%m, %k, %k, %m)
|
/// (%m, %k, %k, %m)
|
||||||
mlir::AffineMap loopsToOperandRangesMap();
|
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
|
||||||
|
|
||||||
/// Given an enclosing parallel loop with iv `i` and an enclosing parallel
|
/// Given an enclosing parallel loop with iv `i` and an enclosing parallel
|
||||||
/// loop with iv `r_j`, emits MLIR corresponding to:
|
/// loop with iv `r_j`, emits MLIR corresponding to:
|
||||||
|
@ -255,7 +259,7 @@ public:
|
||||||
/// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
|
/// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
|
||||||
/// And the operands ranges are:
|
/// And the operands ranges are:
|
||||||
/// (%m, %k, %k, %n, %m, %n)
|
/// (%m, %k, %k, %n, %m, %n)
|
||||||
mlir::AffineMap loopsToOperandRangesMap();
|
llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
|
||||||
|
|
||||||
/// Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing
|
/// Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing
|
||||||
/// reduction loop with iv `r_k`, emits MLIR corresponding to:
|
/// reduction loop with iv `r_k`, emits MLIR corresponding to:
|
||||||
|
|
|
@ -29,20 +29,20 @@
|
||||||
|
|
||||||
template <class ConcreteOp>
|
template <class ConcreteOp>
|
||||||
mlir::Value *
|
mlir::Value *
|
||||||
linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned i) {
|
linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) {
|
||||||
return *(getInputs().begin() + i);
|
return *(getInputs().begin() + viewIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class ConcreteOp>
|
template <class ConcreteOp>
|
||||||
mlir::Value *
|
mlir::Value *
|
||||||
linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned i) {
|
linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) {
|
||||||
return *(getOutputs().begin() + i);
|
return *(getOutputs().begin() + viewIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class ConcreteOp>
|
template <class ConcreteOp>
|
||||||
mlir::AffineMap
|
llvm::SmallVector<mlir::AffineMap, 8>
|
||||||
linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangesMap() {
|
linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() {
|
||||||
return static_cast<ConcreteOp *>(this)->loopsToOperandRangesMap();
|
return static_cast<ConcreteOp *>(this)->loopsToOperandRangeMaps();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class ConcreteOp>
|
template <class ConcreteOp>
|
||||||
|
@ -56,7 +56,24 @@ void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
|
||||||
template <class ConcreteOp>
|
template <class ConcreteOp>
|
||||||
mlir::AffineMap linalg::operandRangesToLoopsMap(
|
mlir::AffineMap linalg::operandRangesToLoopsMap(
|
||||||
linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
|
linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
|
||||||
return inverseSubMap(tensorContraction.loopsToOperandRangesMap());
|
mlir::AffineMap current;
|
||||||
|
// Individual submaps may not be invertible but their union must be invertible
|
||||||
|
// by construction.
|
||||||
|
for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
|
||||||
|
if (!m)
|
||||||
|
continue;
|
||||||
|
if (!current) {
|
||||||
|
current = m;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
|
||||||
|
current.getResults().end());
|
||||||
|
results.append(m.getResults().begin(), m.getResults().end());
|
||||||
|
current = mlir::AffineMap::get(
|
||||||
|
std::max(current.getNumDims(), m.getNumDims()),
|
||||||
|
current.getNumSymbols() + m.getNumSymbols(), results, {});
|
||||||
|
}
|
||||||
|
return inverseSubMap(current);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract the ranges from a given ViewOp or SliceOp.
|
// Extract the ranges from a given ViewOp or SliceOp.
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
namespace linalg {
|
namespace linalg {
|
||||||
|
|
||||||
///
|
///
|
||||||
/// Ideally all these functions would go in an Analysis but until
|
/// Ideally all these functions would go in an Analysis but as long as
|
||||||
/// TensorContractionBase is templated, they need to remain close enough.
|
/// TensorContractionBase is templated, they need to remain close enough.
|
||||||
///
|
///
|
||||||
|
|
||||||
|
|
|
@ -19,14 +19,40 @@
|
||||||
#define LINALG3_TRANSFORMS_H_
|
#define LINALG3_TRANSFORMS_H_
|
||||||
|
|
||||||
#include "linalg2/Transforms.h"
|
#include "linalg2/Transforms.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/Optional.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
class AffineForOp;
|
||||||
|
class AffineMap;
|
||||||
class Function;
|
class Function;
|
||||||
class FunctionPassBase;
|
class FunctionPassBase;
|
||||||
|
class Operation;
|
||||||
|
class Value;
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
namespace linalg {
|
namespace linalg {
|
||||||
|
|
||||||
|
struct RangeParts {
|
||||||
|
explicit RangeParts(unsigned reserved);
|
||||||
|
RangeParts(llvm::ArrayRef<mlir::Value *> ranges);
|
||||||
|
llvm::SmallVector<mlir::Value *, 4> makeRanges();
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value *, 4> mins;
|
||||||
|
llvm::SmallVector<mlir::Value *, 4> maxes;
|
||||||
|
llvm::SmallVector<mlir::Value *, 4> steps;
|
||||||
|
};
|
||||||
|
|
||||||
|
mlir::Value *
|
||||||
|
makeFoldedComposedAffineApply(mlir::AffineMap map,
|
||||||
|
llvm::ArrayRef<mlir::Value *> operandsRef);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value *, 4>
|
||||||
|
makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps,
|
||||||
|
llvm::ArrayRef<mlir::Value *> ranges,
|
||||||
|
llvm::ArrayRef<mlir::Value *> tileSizes = {});
|
||||||
|
|
||||||
/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
|
/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
|
||||||
/// to only use linalg.view operations.
|
/// to only use linalg.view operations.
|
||||||
void composeSliceOps(mlir::Function *f);
|
void composeSliceOps(mlir::Function *f);
|
||||||
|
@ -35,6 +61,13 @@ void composeSliceOps(mlir::Function *f);
|
||||||
/// as linalg.matvec (resp. linalg.dot).
|
/// as linalg.matvec (resp. linalg.dot).
|
||||||
void lowerToFinerGrainedTensorContraction(mlir::Function *f);
|
void lowerToFinerGrainedTensorContraction(mlir::Function *f);
|
||||||
|
|
||||||
|
/// Operation-wise writing of linalg operations to loop form.
|
||||||
|
/// It is the caller's responsibility to erase the `op` if necessary.
|
||||||
|
/// This returns the enclosing loops around the body of `op` for further
|
||||||
|
/// composition of transformations.
|
||||||
|
llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 4>>
|
||||||
|
writeAsLoops(mlir::Operation *op);
|
||||||
|
|
||||||
/// Traverses `f` and rewrites linalg operations in loop form.
|
/// Traverses `f` and rewrites linalg operations in loop form.
|
||||||
void lowerToLoops(mlir::Function *f);
|
void lowerToLoops(mlir::Function *f);
|
||||||
|
|
||||||
|
|
|
@ -39,14 +39,16 @@ using namespace linalg::intrinsics;
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
// Implementation of DotOp.
|
// Implementation of DotOp.
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
AffineMap linalg::DotOp::loopsToOperandRangesMap() {
|
SmallVector<AffineMap, 8> linalg::DotOp::loopsToOperandRangeMaps() {
|
||||||
// A(K), B(K), C()
|
// A(K), B(K), C()
|
||||||
assert(getRanges(*this).size() == 2);
|
assert(getRanges(*this).size() == 2);
|
||||||
auto *context = ScopedContext::getContext();
|
auto *context = ScopedContext::getContext();
|
||||||
auto d0 = getAffineDimExpr(0, context); // K
|
auto d0 = getAffineDimExpr(0, context); // K
|
||||||
// A(K), B(K), C()
|
// A(K), B(K), C()
|
||||||
// (d0) -> (d0, d0)(%k)
|
// (d0) -> (d0, d0)(%k)
|
||||||
return AffineMap::get(1, 0, {d0, d0}, {});
|
return SmallVector<AffineMap, 8>{AffineMap::get(1, 0, {d0}, {}), // A(K)
|
||||||
|
AffineMap::get(1, 0, {d0}, {}), // B(K)
|
||||||
|
AffineMap()}; // C()
|
||||||
}
|
}
|
||||||
|
|
||||||
void linalg::DotOp::emitScalarImplementation(
|
void linalg::DotOp::emitScalarImplementation(
|
||||||
|
@ -75,7 +77,7 @@ void linalg::DotOp::emitScalarImplementation(
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
// Implementation of MatvecOp.
|
// Implementation of MatvecOp.
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
AffineMap linalg::MatvecOp::loopsToOperandRangesMap() {
|
SmallVector<AffineMap, 8> linalg::MatvecOp::loopsToOperandRangeMaps() {
|
||||||
// A(M, K), B(K), C(M)
|
// A(M, K), B(K), C(M)
|
||||||
assert(getRanges(*this).size() == 4);
|
assert(getRanges(*this).size() == 4);
|
||||||
auto *context = ScopedContext::getContext();
|
auto *context = ScopedContext::getContext();
|
||||||
|
@ -83,7 +85,10 @@ AffineMap linalg::MatvecOp::loopsToOperandRangesMap() {
|
||||||
auto d1 = getAffineDimExpr(1, context); // K
|
auto d1 = getAffineDimExpr(1, context); // K
|
||||||
// A(M, K), B(K), C(M)
|
// A(M, K), B(K), C(M)
|
||||||
// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
|
// (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
|
||||||
return AffineMap::get(2, 0, {d0, d1, d1, d0}, {});
|
return SmallVector<AffineMap, 8>{
|
||||||
|
AffineMap::get(2, 0, {d0, d1}, {}), // A(M, K)
|
||||||
|
AffineMap::get(2, 0, {d1}, {}), // B(K)
|
||||||
|
AffineMap::get(2, 0, {d0}, {})}; // C(M)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
|
// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
|
||||||
|
@ -135,9 +140,9 @@ void linalg::MatvecOp::emitScalarImplementation(
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
// Op-specific Matmul.
|
// Implementation of Matmul.
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
AffineMap linalg::MatmulOp::loopsToOperandRangesMap() {
|
SmallVector<AffineMap, 8> linalg::MatmulOp::loopsToOperandRangeMaps() {
|
||||||
// A(M, K), B(K, N), C(M, N)
|
// A(M, K), B(K, N), C(M, N)
|
||||||
assert(getRanges(*this).size() == 6);
|
assert(getRanges(*this).size() == 6);
|
||||||
auto *context = ScopedContext::getContext();
|
auto *context = ScopedContext::getContext();
|
||||||
|
@ -146,7 +151,11 @@ AffineMap linalg::MatmulOp::loopsToOperandRangesMap() {
|
||||||
auto d2 = getAffineDimExpr(2, context); // K
|
auto d2 = getAffineDimExpr(2, context); // K
|
||||||
// A(M, K), B(K, N), C(M, N):
|
// A(M, K), B(K, N), C(M, N):
|
||||||
// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
|
// (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
|
||||||
return AffineMap::get(3, 0, {d0, d2, d2, d1, d0, d1}, {});
|
return SmallVector<AffineMap, 8>{
|
||||||
|
AffineMap::get(3, 0, {d0, d2}, {}), // A(M, K)
|
||||||
|
AffineMap::get(3, 0, {d2, d1}, {}), // B(K, N)
|
||||||
|
AffineMap::get(3, 0, {d0, d1}, {}) // C(M, N)
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
|
// The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
|
||||||
|
|
|
@ -71,8 +71,8 @@ static Value *tryFold(AffineMap map, SmallVector<Value *, 4> operands) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value *makeFoldedComposedAffineApply(AffineMap map,
|
Value *linalg::makeFoldedComposedAffineApply(AffineMap map,
|
||||||
ArrayRef<Value *> operandsRef) {
|
ArrayRef<Value *> operandsRef) {
|
||||||
SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
|
SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
|
||||||
fullyComposeAffineMapAndOperands(&map, &operands);
|
fullyComposeAffineMapAndOperands(&map, &operands);
|
||||||
if (auto *v = tryFold(map, operands)) {
|
if (auto *v = tryFold(map, operands)) {
|
||||||
|
@ -83,18 +83,7 @@ static Value *makeFoldedComposedAffineApply(AffineMap map,
|
||||||
return b->create<AffineApplyOp>(loc, map, operands).getResult();
|
return b->create<AffineApplyOp>(loc, map, operands).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RangeParts {
|
linalg::RangeParts::RangeParts(unsigned reserved) {
|
||||||
explicit RangeParts(unsigned reserved);
|
|
||||||
RangeParts(ArrayRef<Value *> ranges);
|
|
||||||
|
|
||||||
SmallVector<Value *, 4> makeRanges();
|
|
||||||
|
|
||||||
SmallVector<Value *, 4> mins;
|
|
||||||
SmallVector<Value *, 4> maxes;
|
|
||||||
SmallVector<Value *, 4> steps;
|
|
||||||
};
|
|
||||||
|
|
||||||
RangeParts::RangeParts(unsigned reserved) {
|
|
||||||
mins.reserve(reserved);
|
mins.reserve(reserved);
|
||||||
maxes.reserve(reserved);
|
maxes.reserve(reserved);
|
||||||
steps.reserve(reserved);
|
steps.reserve(reserved);
|
||||||
|
@ -112,12 +101,12 @@ extractFromRanges(ArrayRef<Value *> ranges,
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
RangeParts::RangeParts(ArrayRef<Value *> ranges)
|
linalg::RangeParts::RangeParts(ArrayRef<Value *> ranges)
|
||||||
: mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
|
: mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
|
||||||
maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
|
maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
|
||||||
steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}
|
steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}
|
||||||
|
|
||||||
SmallVector<Value *, 4> RangeParts::makeRanges() {
|
SmallVector<Value *, 4> linalg::RangeParts::makeRanges() {
|
||||||
SmallVector<Value *, 4> res;
|
SmallVector<Value *, 4> res;
|
||||||
res.reserve(mins.size());
|
res.reserve(mins.size());
|
||||||
for (auto z : llvm::zip(mins, maxes, steps)) {
|
for (auto z : llvm::zip(mins, maxes, steps)) {
|
||||||
|
@ -149,14 +138,15 @@ SmallVector<Value *, 4> makeGenericRanges(AffineMap map,
|
||||||
return makeGenericRangeParts(map, ranges).makeRanges();
|
return makeGenericRangeParts(map, ranges).makeRanges();
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<Value *, 4> makeGenericLoopRanges(
|
SmallVector<Value *, 4>
|
||||||
AffineMap operandRangesToLoopsMap, ArrayRef<Value *> ranges,
|
linalg::makeGenericLoopRanges(AffineMap operandRangesToLoopMaps,
|
||||||
llvm::Optional<ArrayRef<Value *>> tileSizes = llvm::None) {
|
ArrayRef<Value *> ranges,
|
||||||
RangeParts res = makeGenericRangeParts(operandRangesToLoopsMap, ranges);
|
ArrayRef<Value *> tileSizes) {
|
||||||
if (!tileSizes.hasValue())
|
RangeParts res = makeGenericRangeParts(operandRangesToLoopMaps, ranges);
|
||||||
|
if (tileSizes.empty())
|
||||||
return res.makeRanges();
|
return res.makeRanges();
|
||||||
SmallVector<Value *, 4> tiledSteps;
|
SmallVector<Value *, 4> tiledSteps;
|
||||||
for (auto z : llvm::zip(res.steps, *tileSizes)) {
|
for (auto z : llvm::zip(res.steps, tileSizes)) {
|
||||||
auto *step = std::get<0>(z);
|
auto *step = std::get<0>(z);
|
||||||
auto tileSize = std::get<1>(z);
|
auto tileSize = std::get<1>(z);
|
||||||
auto stepValue = step->getDefiningOp()->cast<ConstantIndexOp>().getValue();
|
auto stepValue = step->getDefiningOp()->cast<ConstantIndexOp>().getValue();
|
||||||
|
@ -171,11 +161,12 @@ static SmallVector<Value *, 4> makeGenericLoopRanges(
|
||||||
|
|
||||||
template <class ContractionOp>
|
template <class ContractionOp>
|
||||||
static SmallVector<mlir::AffineForOp, 4>
|
static SmallVector<mlir::AffineForOp, 4>
|
||||||
writeAsLoops(ContractionOp contraction) {
|
writeContractionAsLoops(ContractionOp contraction) {
|
||||||
ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()),
|
ScopedContext scope(FuncBuilder(contraction.getOperation()),
|
||||||
contraction.getLoc());
|
contraction.getLoc());
|
||||||
auto loopRanges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
|
auto allRanges = getRanges(contraction);
|
||||||
getRanges(contraction));
|
auto loopRanges =
|
||||||
|
makeGenericLoopRanges(operandRangesToLoopsMap(contraction), allRanges);
|
||||||
|
|
||||||
SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
|
SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
|
||||||
SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
|
SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
|
||||||
|
@ -201,27 +192,33 @@ writeAsLoops(ContractionOp contraction) {
|
||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
SmallVector<mlir::AffineForOp, 4> res;
|
// Return the AffineForOp for better compositionality (e.g. tiling).
|
||||||
res.reserve(pivs.size() + rivs.size());
|
SmallVector<mlir::AffineForOp, 4> loops;
|
||||||
|
loops.reserve(pivs.size() + rivs.size());
|
||||||
for (auto iv : parallelIvs)
|
for (auto iv : parallelIvs)
|
||||||
res.push_back(getForInductionVarOwner(iv.getValue()));
|
loops.push_back(getForInductionVarOwner(iv.getValue()));
|
||||||
for (auto iv : reductionIvs)
|
for (auto iv : reductionIvs)
|
||||||
res.push_back(getForInductionVarOwner(iv.getValue()));
|
loops.push_back(getForInductionVarOwner(iv.getValue()));
|
||||||
return res;
|
|
||||||
|
return loops;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::Optional<SmallVector<mlir::AffineForOp, 4>>
|
||||||
|
linalg::writeAsLoops(Operation *op) {
|
||||||
|
if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
|
||||||
|
return writeContractionAsLoops(matmulOp);
|
||||||
|
} else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
|
||||||
|
return writeContractionAsLoops(matvecOp);
|
||||||
|
} else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) {
|
||||||
|
return writeContractionAsLoops(dotOp);
|
||||||
|
}
|
||||||
|
return llvm::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
void linalg::lowerToLoops(mlir::Function *f) {
|
void linalg::lowerToLoops(mlir::Function *f) {
|
||||||
f->walk([](Operation *op) {
|
f->walk([](Operation *op) {
|
||||||
if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
|
if (writeAsLoops(op))
|
||||||
writeAsLoops(matmulOp);
|
op->erase();
|
||||||
} else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
|
|
||||||
writeAsLoops(matvecOp);
|
|
||||||
} else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) {
|
|
||||||
writeAsLoops(dotOp);
|
|
||||||
} else {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
op->erase();
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,176 @@
|
||||||
|
//===- Example.cpp - Our running example ----------------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// RUN: %p/test | FileCheck %s
|
||||||
|
|
||||||
|
#include "TestHarness.h"
|
||||||
|
#include "linalg1/Common.h"
|
||||||
|
#include "linalg2/Intrinsics.h"
|
||||||
|
#include "linalg3/Ops.h"
|
||||||
|
#include "linalg4/Transforms.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
|
||||||
|
using llvm::StringRef;
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::edsc;
|
||||||
|
using namespace mlir::edsc::intrinsics;
|
||||||
|
using namespace linalg;
|
||||||
|
using namespace linalg::common;
|
||||||
|
using namespace linalg::intrinsics;
|
||||||
|
|
||||||
|
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||||
|
MLIRContext *context = module.getContext();
|
||||||
|
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||||
|
mlir::Function *f = linalg::common::makeFunction(
|
||||||
|
module, name,
|
||||||
|
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
|
||||||
|
|
||||||
|
ScopedContext scope(f);
|
||||||
|
// clang-format off
|
||||||
|
ValueHandle
|
||||||
|
M = dim(f->getArgument(0), 0),
|
||||||
|
N = dim(f->getArgument(2), 1),
|
||||||
|
K = dim(f->getArgument(0), 1),
|
||||||
|
rM = range(constant_index(0), M, constant_index(1)),
|
||||||
|
rN = range(constant_index(0), N, constant_index(1)),
|
||||||
|
rK = range(constant_index(0), K, constant_index(1)),
|
||||||
|
vA = view(f->getArgument(0), {rM, rK}),
|
||||||
|
vB = view(f->getArgument(1), {rK, rN}),
|
||||||
|
vC = view(f->getArgument(2), {rM, rN});
|
||||||
|
matmul(vA, vB, vC);
|
||||||
|
ret();
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_FUNC(matmul_tiled_loops) {
|
||||||
|
MLIRContext context;
|
||||||
|
Module module(&context);
|
||||||
|
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
|
||||||
|
lowerToTiledLoops(f, {8, 9});
|
||||||
|
PassManager pm;
|
||||||
|
pm.addPass(createLowerLinalgLoadStorePass());
|
||||||
|
if (succeeded(pm.run(f->getModule())))
|
||||||
|
cleanupAndPrintFunction(f);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
// CHECK-LABEL: func @matmul_tiled_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
|
||||||
|
// CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
|
||||||
|
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
|
||||||
|
// CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
|
||||||
|
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) step 8 {
|
||||||
|
// CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 9 {
|
||||||
|
// CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
|
||||||
|
// CHECK: affine.for %i3 = max (d0)[s0] -> (s0, d0)(%i0)[%{{.*}}] to min (d0)[s0] -> (s0, d0 + 8)(%i0)[%[[M]]] {
|
||||||
|
// CHECK: affine.for %i4 = max (d0)[s0] -> (s0, d0)(%i1)[%{{.*}}] to min (d0)[s0] -> (s0, d0 + 9)(%i1)[%[[N]]] {
|
||||||
|
// CHECK-NEXT: %{{.*}} = cmpi "eq", %i2, %{{.*}} : index
|
||||||
|
// CHECK-NEXT: %[[I3:.*]] = affine.apply (d0) -> (d0)(%i3)
|
||||||
|
// CHECK-NEXT: %[[I4:.*]] = affine.apply (d0) -> (d0)(%i4)
|
||||||
|
// CHECK-NEXT: %{{.*}} = load %arg2[%[[I3]], %[[I4]]] : memref<?x?xf32>
|
||||||
|
// CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
|
||||||
|
// CHECK-NEXT: %[[I2:.*]] = affine.apply (d0) -> (d0)(%i2)
|
||||||
|
// CHECK-NEXT: %{{.*}} = load %arg1[%[[I2]], %[[I4]]] : memref<?x?xf32>
|
||||||
|
// CHECK-NEXT: %{{.*}} = load %arg0[%[[I3]], %[[I2]]] : memref<?x?xf32>
|
||||||
|
// CHECK-NEXT: %{{.*}} = mulf %10, %9 : f32
|
||||||
|
// CHECK-NEXT: %{{.*}} = addf %7, %11 : f32
|
||||||
|
// CHECK-NEXT: store %{{.*}}, %arg2[%[[I3]], %[[I4]]] : memref<?x?xf32>
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_FUNC(matmul_tiled_views) {
|
||||||
|
MLIRContext context;
|
||||||
|
Module module(&context);
|
||||||
|
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
|
||||||
|
FuncBuilder b(f);
|
||||||
|
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8),
|
||||||
|
b.create<ConstantIndexOp>(f->getLoc(), 9)});
|
||||||
|
composeSliceOps(f);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
// CHECK-LABEL: func @matmul_tiled_views(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
|
||||||
|
// CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
|
||||||
|
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
|
||||||
|
// CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
|
||||||
|
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) step 8 {
|
||||||
|
// CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 9 {
|
||||||
|
// CHECK-NEXT: %[[i0min:.*]] = affine.apply (d0) -> (d0)(%i0)
|
||||||
|
// CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
|
||||||
|
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg<"range">
|
||||||
|
// CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range">
|
||||||
|
// CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view<f32xf32>">
|
||||||
|
// CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
|
||||||
|
// CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
|
||||||
|
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg<"range">
|
||||||
|
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg<"view<f32xf32>">
|
||||||
|
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg<"view<f32xf32>">
|
||||||
|
// CHECK-NEXT: linalg.matmul {%[[vA]], %[[vB]]} -> {%[[vC]]}
|
||||||
|
// clang-format on
|
||||||
|
cleanupAndPrintFunction(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_FUNC(matmul_tiled_views_as_loops) {
|
||||||
|
MLIRContext context;
|
||||||
|
Module module(&context);
|
||||||
|
mlir::Function *f =
|
||||||
|
makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops");
|
||||||
|
FuncBuilder b(f);
|
||||||
|
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8),
|
||||||
|
b.create<ConstantIndexOp>(f->getLoc(), 9)});
|
||||||
|
composeSliceOps(f);
|
||||||
|
lowerToLoops(f);
|
||||||
|
// This cannot lower below linalg.load and linalg.store due to lost
|
||||||
|
// information related to loop bounds and tiling. There are multiple ways to
|
||||||
|
// attack the problem, the best one is an IR change.
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
// CHECK-LABEL: func @matmul_tiled_views_as_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
|
||||||
|
// CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
|
||||||
|
// CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
|
||||||
|
// CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
|
||||||
|
// CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) step 8 {
|
||||||
|
// CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 9 {
|
||||||
|
// CHECK-NEXT: %[[i0min:.*]] = affine.apply (d0) -> (d0)(%i0)
|
||||||
|
// CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
|
||||||
|
// CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg<"range">
|
||||||
|
// CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range">
|
||||||
|
// CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view<f32xf32>">
|
||||||
|
// CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
|
||||||
|
// CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
|
||||||
|
// CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg<"range">
|
||||||
|
// CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg<"view<f32xf32>">
|
||||||
|
// CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg<"view<f32xf32>">
|
||||||
|
// CHECK-NEXT: affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0)(%[[i0max]]) {
|
||||||
|
// CHECK-NEXT: affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0)(%[[i1max]]) {
|
||||||
|
// CHECK-NEXT: affine.for %i4 = 0 to (d0) -> (d0)(%[[K]]) {
|
||||||
|
// CHECK-NEXT: %{{.*}} = cmpi "eq", %i4, %c0 : index
|
||||||
|
// CHECK-NEXT: %{{.*}} = linalg.load %[[vC]][%i2, %i3] : !linalg<"view<f32xf32>">
|
||||||
|
// CHECK-NEXT: %{{.*}} = select %{{.*}}, %cst, %{{.*}} : f32
|
||||||
|
// CHECK-NEXT: %{{.*}} = linalg.load %[[vB]][%i4, %i3] : !linalg<"view<f32xf32>">
|
||||||
|
// CHECK-NEXT: %{{.*}} = linalg.load %[[vA]][%i2, %i4] : !linalg<"view<f32xf32>">
|
||||||
|
// CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
|
||||||
|
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
|
||||||
|
// CHECK-NEXT: linalg.store %{{.*}}, %[[vC]][%i2, %i3] : !linalg<"view<f32xf32>">
|
||||||
|
// clang-format on
|
||||||
|
cleanupAndPrintFunction(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
RUN_TESTS();
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
//===- Transforms.h - Linalg dialect Transformations definition -----------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#ifndef LINALG4_TRANSFORMS_H_
|
||||||
|
#define LINALG4_TRANSFORMS_H_
|
||||||
|
|
||||||
|
#include "linalg3/Transforms.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
namespace linalg {
|
||||||
|
|
||||||
|
/// Rewrites a linalg `op` in tiled loop form and erases `op`.
|
||||||
|
llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 8>>
|
||||||
|
writeAsTiledLoops(mlir::Operation *op, llvm::ArrayRef<uint64_t> tileSizes);
|
||||||
|
|
||||||
|
/// Rewrites a linalg `op` in tiled view form and erases `op`.
|
||||||
|
llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 8>>
|
||||||
|
writeAsTiledViews(mlir::Operation *op, llvm::ArrayRef<mlir::Value *> tileSizes);
|
||||||
|
|
||||||
|
/// Apply `writeAsTiledLoops` on all linalg ops. This is a convenience function
|
||||||
|
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
|
||||||
|
/// in a function can generally not be specified.
|
||||||
|
void lowerToTiledLoops(mlir::Function *f, llvm::ArrayRef<uint64_t> tileSizes);
|
||||||
|
|
||||||
|
/// Apply `writeAsTiledViews` on all linalg ops. This is a convenience function
|
||||||
|
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
|
||||||
|
/// in a function can generally not be specified.
|
||||||
|
void lowerToTiledViews(mlir::Function *f,
|
||||||
|
llvm::ArrayRef<mlir::Value *> tileSizes);
|
||||||
|
|
||||||
|
} // namespace linalg
|
||||||
|
|
||||||
|
#endif // LINALG4_TRANSFORMS_H_
|
|
@ -0,0 +1,214 @@
|
||||||
|
//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements analyses and transformations for the linalg dialect.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "linalg4/Transforms.h"
|
||||||
|
#include "linalg3/Intrinsics.h"
|
||||||
|
#include "linalg3/TensorOps.h"
|
||||||
|
|
||||||
|
#include "mlir/AffineOps/AffineOps.h"
|
||||||
|
#include "mlir/EDSC/Helpers.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/Transforms/LoopUtils.h"
|
||||||
|
|
||||||
|
using llvm::ArrayRef;
|
||||||
|
using llvm::SmallVector;
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::edsc;
|
||||||
|
using namespace linalg;
|
||||||
|
using namespace linalg::intrinsics;
|
||||||
|
|
||||||
|
llvm::Optional<SmallVector<mlir::AffineForOp, 8>>
|
||||||
|
linalg::writeAsTiledLoops(Operation *op, ArrayRef<uint64_t> tileSizes) {
|
||||||
|
auto loops = writeAsLoops(op);
|
||||||
|
if (loops.hasValue())
|
||||||
|
return mlir::tile(*loops, tileSizes, loops->back());
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
void linalg::lowerToTiledLoops(mlir::Function *f,
|
||||||
|
ArrayRef<uint64_t> tileSizes) {
|
||||||
|
f->walk([tileSizes](Operation *op) {
|
||||||
|
if (writeAsTiledLoops(op, tileSizes).hasValue())
|
||||||
|
op->erase();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ConcreteOp>
|
||||||
|
static Operation::operand_range
|
||||||
|
getInputsAndOutputs(TensorContractionBase<ConcreteOp> &contraction) {
|
||||||
|
auto *inst = static_cast<ConcreteOp *>(&contraction)->getOperation();
|
||||||
|
auto begin = inst->operand_begin();
|
||||||
|
auto end = inst->operand_begin() + contraction.getNumInputs() +
|
||||||
|
contraction.getNumOutputs();
|
||||||
|
return {begin, end};
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isZeroIndex(Value *v) {
|
||||||
|
return v->getDefiningOp() && v->getDefiningOp()->isa<ConstantIndexOp>() &&
|
||||||
|
v->getDefiningOp()->dyn_cast<ConstantIndexOp>().getValue() == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ConcreteOp>
|
||||||
|
static llvm::SmallVector<Value *, 4>
|
||||||
|
makeTiledRanges(TensorContractionBase<ConcreteOp> &contraction,
|
||||||
|
ArrayRef<Value *> allRanges, llvm::ArrayRef<Value *> ivs,
|
||||||
|
llvm::ArrayRef<Value *> tileSizes) {
|
||||||
|
assert(ivs.size() == tileSizes.size());
|
||||||
|
if (ivs.empty())
|
||||||
|
return RangeParts(allRanges).makeRanges();
|
||||||
|
|
||||||
|
auto *op = static_cast<ConcreteOp *>(&contraction);
|
||||||
|
RangeParts result(allRanges.size());
|
||||||
|
RangeParts rangeParts(allRanges);
|
||||||
|
|
||||||
|
for (auto map : op->loopsToOperandRangeMaps()) {
|
||||||
|
// 1. Take the first ivs results of the map, the other ones are not composed
|
||||||
|
// but merely copied over.
|
||||||
|
assert(map.getNumSymbols() == 0);
|
||||||
|
assert(map.getRangeSizes().empty());
|
||||||
|
MLIRContext *context = ScopedContext::getContext();
|
||||||
|
unsigned numParallel = op->getNumParallelDims();
|
||||||
|
unsigned numReduction = op->getNumReductionDims();
|
||||||
|
if (ivs.size() < numParallel + numReduction) {
|
||||||
|
// Inject zeros in positions that are not tiled.
|
||||||
|
SmallVector<AffineExpr, 4> dimReplacements(numParallel + numReduction);
|
||||||
|
for (unsigned i = 0, e = numParallel + numReduction; i < e; ++i) {
|
||||||
|
dimReplacements[i] = (i < ivs.size())
|
||||||
|
? getAffineDimExpr(i, context)
|
||||||
|
: getAffineConstantExpr(0, context);
|
||||||
|
}
|
||||||
|
map = map.replaceDimsAndSymbols(dimReplacements, {}, ivs.size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Apply the rewritten map to the ranges.
|
||||||
|
unsigned numDims = map.getNumDims();
|
||||||
|
for (auto en : llvm::enumerate(map.getResults())) {
|
||||||
|
auto index = en.index();
|
||||||
|
auto expr = en.value();
|
||||||
|
AffineMap exprMap = AffineMap::get(numDims, 0, expr, {});
|
||||||
|
ValueHandle offset(makeFoldedComposedAffineApply(exprMap, ivs));
|
||||||
|
// Offset is normally a function of loop induction variables.
|
||||||
|
// If it is 0, it must come from a dimension that was not tiled.
|
||||||
|
if (isZeroIndex(offset)) {
|
||||||
|
result.mins.push_back(rangeParts.mins[index]);
|
||||||
|
result.maxes.push_back(rangeParts.maxes[index]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueHandle step(makeFoldedComposedAffineApply(exprMap, tileSizes));
|
||||||
|
ValueHandle min(rangeParts.mins[index]);
|
||||||
|
using edsc::op::operator+;
|
||||||
|
result.mins.push_back(min + offset);
|
||||||
|
// Ideally this should be:
|
||||||
|
// `min(rangeParts.max, rangeParts.min + offset + step)`
|
||||||
|
// but that breaks the current limitations of the affine dialect.
|
||||||
|
result.maxes.push_back(min + offset + step);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Note that for the purpose of tiled ranges and views, the steps do not
|
||||||
|
// change in our representation.
|
||||||
|
result.steps = rangeParts.steps;
|
||||||
|
|
||||||
|
return result.makeRanges();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ConcreteOp>
|
||||||
|
static SmallVector<Value *, 4>
|
||||||
|
makeTiledViews(linalg::TensorContractionBase<ConcreteOp> &contraction,
|
||||||
|
ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes) {
|
||||||
|
auto tiledRanges =
|
||||||
|
makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes);
|
||||||
|
SmallVector<Value *, 4> res;
|
||||||
|
unsigned currentRange = 0;
|
||||||
|
for (auto *in : getInputsAndOutputs(contraction)) {
|
||||||
|
unsigned runningSliceDim = 0;
|
||||||
|
auto *runningSlice = in;
|
||||||
|
assert(runningSlice->getType().template isa<ViewType>());
|
||||||
|
for (unsigned d = 0, e = getViewRank(runningSlice); d < e; ++d) {
|
||||||
|
auto *r = tiledRanges[currentRange++];
|
||||||
|
runningSlice = slice(runningSlice, r, runningSliceDim++).getValue();
|
||||||
|
}
|
||||||
|
res.push_back(runningSlice);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ConcreteOp>
|
||||||
|
static SmallVector<mlir::AffineForOp, 8>
|
||||||
|
writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
|
||||||
|
ArrayRef<Value *> tileSizes) {
|
||||||
|
assert(tileSizes.size() <=
|
||||||
|
contraction.getNumParallelDims() + contraction.getNumReductionDims());
|
||||||
|
|
||||||
|
auto *op = static_cast<ConcreteOp *>(&contraction);
|
||||||
|
ScopedContext scope(mlir::FuncBuilder(op->getOperation()), op->getLoc());
|
||||||
|
SmallVector<IndexHandle, 4> ivs(tileSizes.size());
|
||||||
|
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
using linalg::common::LoopNestRangeBuilder;
|
||||||
|
auto ranges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
|
||||||
|
getRanges(contraction), tileSizes);
|
||||||
|
linalg::common::LoopNestRangeBuilder(pivs, ranges)({
|
||||||
|
[&contraction, &tileSizes, &ivs]() {
|
||||||
|
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
|
||||||
|
auto views = makeTiledViews(contraction, ivValues, tileSizes);
|
||||||
|
ScopedContext::getBuilder()->create<ConcreteOp>(
|
||||||
|
ScopedContext::getLocation(), views);
|
||||||
|
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||||
|
return IndexHandle();
|
||||||
|
}()
|
||||||
|
});
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
SmallVector<mlir::AffineForOp, 8> res;
|
||||||
|
res.reserve(ivs.size());
|
||||||
|
for (auto iv : ivs)
|
||||||
|
res.push_back(getForInductionVarOwner(iv.getValue()));
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::Optional<SmallVector<mlir::AffineForOp, 8>>
|
||||||
|
linalg::writeAsTiledViews(Operation *op, ArrayRef<Value *> tileSizes) {
|
||||||
|
if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
|
||||||
|
return writeContractionAsTiledViews(matmulOp, tileSizes);
|
||||||
|
} else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
|
||||||
|
return writeContractionAsTiledViews(matvecOp, tileSizes);
|
||||||
|
} else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) {
|
||||||
|
return writeContractionAsTiledViews(dotOp, tileSizes);
|
||||||
|
}
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef<Value *> tileSizes) {
|
||||||
|
f->walk([tileSizes](Operation *op) {
|
||||||
|
if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
|
||||||
|
writeAsTiledViews(matmulOp, tileSizes);
|
||||||
|
} else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
|
||||||
|
writeAsTiledViews(matvecOp, tileSizes);
|
||||||
|
} else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) {
|
||||||
|
writeAsTiledViews(dotOp, tileSizes);
|
||||||
|
} else {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
op->erase();
|
||||||
|
});
|
||||||
|
}
|
Loading…
Reference in New Issue