forked from OSchip/llvm-project
Fix Linalg lowering to loops
This CL makes lowering to loops always be a: ``` %D = linalg.dim %view, constant : !linalg.view<...> affine.for %ix = %c0 to %D { ... } ``` This form composes correctly with tiling and is also the proper way to emit loops from views that across function boundaries. The previous version that would extract the range_min/max/step was composing incorrectly with tiling (i.e. would shift by range_min both in the loop bounds and in the slice) and would not work across function boundaries. The relevant tests are updated and a new test `dot_view`---which lowers to loops from views passed as function parameters---is added. When additional context is available, the linalg.dim operations should be folded away but this is left for a future CL. -- PiperOrigin-RevId: 249634712
This commit is contained in:
parent
6a31f9a7e3
commit
c0f41e5bb3
|
@ -336,6 +336,10 @@ public:
|
|||
ArrayRef<Value *> operands) {
|
||||
return impl->create(builder, loc, operands);
|
||||
}
|
||||
Operation::operand_range getInputsAndOutputs() {
|
||||
auto range = this->getOperation()->getOperands();
|
||||
return {range.begin(), range.begin() + getNumInputsAndOutputs()};
|
||||
}
|
||||
|
||||
private:
|
||||
struct Concept {
|
||||
|
|
|
@ -89,8 +89,16 @@ Value *createOrReturnView(FuncBuilder *b, Location loc,
|
|||
enum class RangePart { Min = 0, Max, Step };
|
||||
Value *extractRangePart(Value *range, RangePart part);
|
||||
|
||||
/// Returns the values obtained by applying `map` to the list of values.
|
||||
/// Performs simplifications and foldings where possible.
|
||||
SmallVector<Value *, 4> applyMapToValues(FuncBuilder *b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value *> values,
|
||||
FunctionConstants &state);
|
||||
|
||||
/// Returns the values obtained by applying `map` to the list of range parts
|
||||
/// extracted from `ranges`.
|
||||
/// extracted from `ranges`. Performs simplifications and foldings where
|
||||
/// possible.
|
||||
SmallVector<Value *, 4> applyMapToRangePart(FuncBuilder *b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value *> ranges,
|
||||
|
|
|
@ -24,49 +24,55 @@
|
|||
#include "mlir/Linalg/Passes.h"
|
||||
#include "mlir/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::edsc;
|
||||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
using namespace llvm;
|
||||
|
||||
// Creates a number of ranges equal to the number of results in `map`.
|
||||
// The returned ranges correspond to the loop ranges, in the proper order, for
|
||||
// which new loops will be created.
|
||||
static SmallVector<Value *, 4> makeLoopRanges(FuncBuilder *b, Location loc,
|
||||
static SmallVector<Value *, 4> emitLoopRanges(FuncBuilder *b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value *> allOpRanges,
|
||||
ArrayRef<Value *> allViewSizes,
|
||||
FunctionConstants &state) {
|
||||
// Apply `map` to get mins/maxes/steps in loop order.
|
||||
auto mins =
|
||||
applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Min, state);
|
||||
auto maxes =
|
||||
applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Max, state);
|
||||
auto steps =
|
||||
applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Step, state);
|
||||
|
||||
// Apply `map` to get view sizes in loop order.
|
||||
auto sizes = applyMapToValues(b, loc, map, allViewSizes, state);
|
||||
// Create a new range with the applied tile sizes.
|
||||
SmallVector<Value *, 4> res;
|
||||
for (unsigned idx = 0, e = steps.size(); idx < e; ++idx)
|
||||
res.push_back(b->create<RangeOp>(loc, mins[idx], maxes[idx], steps[idx]));
|
||||
for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
|
||||
res.push_back(b->create<RangeOp>(loc, state.getOrCreateIndex(0), sizes[idx],
|
||||
state.getOrCreateIndex(1)));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// Returns the linearized list of all view dimensions in a linalgOp. Appliying
|
||||
// the inverse, concatenated loopToOperandRangeMaps to this list allows the
|
||||
// derivation of loop ranges for any linalgOp.
|
||||
static SmallVector<Value *, 8> getViewSizes(LinalgOp &linalgOp) {
|
||||
SmallVector<Value *, 8> res;
|
||||
using dim = ValueBuilder<linalg::DimOp>;
|
||||
for (auto v : linalgOp.getInputsAndOutputs()) {
|
||||
ViewType t = v->getType().cast<ViewType>();
|
||||
for (unsigned i = 0; i < t.getRank(); ++i)
|
||||
res.push_back(dim(v, i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static void emitLinalgOpAsLoops(LinalgOp &linalgOp, FunctionConstants &state) {
|
||||
FuncBuilder b(linalgOp.getOperation());
|
||||
ScopedContext scope(b, linalgOp.getOperation()->getLoc());
|
||||
auto loopRanges = makeLoopRanges(
|
||||
auto loopRanges = emitLoopRanges(
|
||||
scope.getBuilder(), scope.getLocation(),
|
||||
// The flattened loopToOperandRangesMaps is expected to be an invertible
|
||||
// permutation map (which is asserted in the inverse calculation).
|
||||
inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))),
|
||||
getRanges(linalgOp.getOperation()), state);
|
||||
getViewSizes(linalgOp), state);
|
||||
|
||||
SmallVector<IndexHandle, 4> parallelIvs(linalgOp.getNumParallelLoops());
|
||||
SmallVector<IndexHandle, 4> reductionIvs(linalgOp.getNumReductionLoops());
|
||||
|
|
|
@ -167,16 +167,10 @@ static Value *emitOrFoldComposedAffineApply(FuncBuilder *b, Location loc,
|
|||
return b->create<AffineApplyOp>(loc, map, operands);
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> mlir::applyMapToRangePart(FuncBuilder *b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value *> ranges,
|
||||
RangePart part,
|
||||
FunctionConstants &state) {
|
||||
SmallVector<Value *, 4> rangeParts(ranges.size());
|
||||
|
||||
llvm::transform(ranges, rangeParts.begin(),
|
||||
[&](Value *range) { return extractRangePart(range, part); });
|
||||
|
||||
SmallVector<Value *, 4> mlir::applyMapToValues(FuncBuilder *b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value *> values,
|
||||
FunctionConstants &state) {
|
||||
SmallVector<Value *, 4> res;
|
||||
res.reserve(map.getNumResults());
|
||||
unsigned numDims = map.getNumDims();
|
||||
|
@ -185,12 +179,22 @@ SmallVector<Value *, 4> mlir::applyMapToRangePart(FuncBuilder *b, Location loc,
|
|||
// folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
|
||||
for (auto expr : map.getResults()) {
|
||||
AffineMap map = AffineMap::get(numDims, 0, expr, {});
|
||||
res.push_back(
|
||||
emitOrFoldComposedAffineApply(b, loc, map, rangeParts, state));
|
||||
res.push_back(emitOrFoldComposedAffineApply(b, loc, map, values, state));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
SmallVector<Value *, 4> mlir::applyMapToRangePart(FuncBuilder *b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<Value *> ranges,
|
||||
RangePart part,
|
||||
FunctionConstants &state) {
|
||||
SmallVector<Value *, 4> rangeParts(ranges.size());
|
||||
llvm::transform(ranges, rangeParts.begin(),
|
||||
[&](Value *range) { return extractRangePart(range, part); });
|
||||
return applyMapToValues(b, loc, map, rangeParts, state);
|
||||
}
|
||||
|
||||
Value *FunctionConstants::getOrCreateIndex(int64_t v) {
|
||||
auto it = map.find(v);
|
||||
if (it != map.end())
|
||||
|
|
|
@ -18,14 +18,17 @@ func @matmul(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: inde
|
|||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
|
||||
// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%arg1) {
|
||||
// CHECK: affine.for %i1 = #[[ID]](%c0) to #[[ID]](%arg2) {
|
||||
// CHECK: affine.for %i2 = #[[ID]](%c0) to #[[ID]](%arg3) {
|
||||
// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[N:.*]] = linalg.dim %[[B]], 1 : !linalg.view<?x?xf32>
|
||||
// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[M]]) {
|
||||
// CHECK: affine.for %i1 = #[[ID]](%c0) to #[[ID]](%[[N]]) {
|
||||
// CHECK: affine.for %i2 = #[[ID]](%c0) to #[[ID]](%[[K]]) {
|
||||
// CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%i0, %i2] : !linalg.view<?x?xf32>
|
||||
// CHECK-DAG: %[[b:.*]] = linalg.load %[[B]][%i2, %i1] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
|
||||
// CHECK: %[[c:.*]] = linalg.load %[[C]][%i0, %i1] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
|
||||
// CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
|
||||
// CHECK-DAG: %[[c:.*]] = linalg.load %[[C]][%i0, %i1] : !linalg.view<?x?xf32>
|
||||
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
|
||||
// CHECK: linalg.store %[[res]], %[[C]][%i0, %i1] : !linalg.view<?x?xf32>
|
||||
|
||||
func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
|
||||
|
@ -43,13 +46,15 @@ func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: inde
|
|||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
|
||||
// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%arg1) {
|
||||
// CHECK: affine.for %i1 = #[[ID]](%c0) to #[[ID]](%arg2) {
|
||||
// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
|
||||
// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
|
||||
// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[M]]) {
|
||||
// CHECK: affine.for %i1 = #[[ID]](%c0) to #[[ID]](%[[K]]) {
|
||||
// CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%i0, %i1] : !linalg.view<?x?xf32>
|
||||
// CHECK-DAG: %[[b:.*]] = linalg.load %[[B]][%i1] : !linalg.view<?xf32>
|
||||
// CHECK: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
|
||||
// CHECK: %[[c:.*]] = linalg.load %[[C]][%i0] : !linalg.view<?xf32>
|
||||
// CHECK: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
|
||||
// CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
|
||||
// CHECK-DAG: %[[c:.*]] = linalg.load %[[C]][%i0] : !linalg.view<?xf32>
|
||||
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
|
||||
// CHECK: linalg.store %[[res]], %[[C]][%i0] : !linalg.view<?xf32>
|
||||
|
||||
func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
|
||||
|
@ -66,10 +71,25 @@ func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index)
|
|||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.view<f32>
|
||||
// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%arg1) {
|
||||
// CHECK: %[[K:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?xf32>
|
||||
// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[K]]) {
|
||||
// CHECK-DAG: %[[a:.*]] = linalg.load %[[A]][%i0] : !linalg.view<?xf32>
|
||||
// CHECK-DAG: %[[b:.*]] = linalg.load %[[B]][%i0] : !linalg.view<?xf32>
|
||||
// CHECK: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
|
||||
// CHECK: %[[c:.*]] = linalg.load %[[C]][] : !linalg.view<f32>
|
||||
// CHECK: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
|
||||
// CHECK: linalg.store %[[res]], %[[C]][] : !linalg.view<f32>
|
||||
// CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
|
||||
// CHECK-DAG: %[[c:.*]] = linalg.load %[[C]][] : !linalg.view<f32>
|
||||
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
|
||||
// CHECK: linalg.store %[[res]], %[[C]][] : !linalg.view<f32>
|
||||
|
||||
func @dot_view(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
|
||||
linalg.dot(%arg0, %arg1, %arg2) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot_view(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
|
||||
// CHECK: %[[K:.*]] = linalg.dim %arg0, 0 : !linalg.view<?xf32>
|
||||
// CHECK: affine.for %i0 = #[[ID]](%c0) to #[[ID]](%[[K]]) {
|
||||
// CHECK-DAG: %[[a:.*]] = linalg.load %arg0[%i0] : !linalg.view<?xf32>
|
||||
// CHECK-DAG: %[[b:.*]] = linalg.load %arg1[%i0] : !linalg.view<?xf32>
|
||||
// CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
|
||||
// CHECK-DAG: %[[c:.*]] = linalg.load %arg2[] : !linalg.view<f32>
|
||||
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
|
||||
// CHECK: linalg.store %[[res]], %arg2[] : !linalg.view<f32>
|
||||
|
|
Loading…
Reference in New Issue