Introduce Deaffinator pass.

This function pass replaces affine_apply operations in CFG functions with
sequences of primitive arithmetic instructions that form the affine map.

The actual replacement functionality is located in LoweringUtils as a
standalone function operating on an individual affine_apply operation and
inserting the result at the location of the original operation.  It is expected
to be useful for other, target-specific lowering passes that may start at
MLFunction level that Deaffinator does not support.

PiperOrigin-RevId: 222406692
This commit is contained in:
Alex Zinenko 2018-11-21 07:42:16 -08:00 committed by jpienaar
parent ac6bfa6780
commit 615c41c788
5 changed files with 344 additions and 0 deletions

View File

@ -0,0 +1,37 @@
//===- LoweringUtils.h ---- Utilities for Lowering Passes -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous utility functions for lowering passes.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INCLUDE_MLIR_TRANSFORMS_LOWERINGUTILS_H
#define MLIR_INCLUDE_MLIR_TRANSFORMS_LOWERINGUTILS_H
namespace mlir {
class AffineApplyOp;
/// Convert the affine_apply operation `op` into a sequence of primitive
/// arithmetic instructions that have the same effect and insert them at the
/// current location of the `op`. Erase the `op` from its parent. Return true
/// if any errors happened during expansion.
bool expandAffineApply(AffineApplyOp &op);
} // namespace mlir
#endif // MLIR_INCLUDE_MLIR_TRANSFORMS_LOWERINGUTILS_H

View File

@ -80,6 +80,12 @@ FunctionPass *createDmaGenerationPass(unsigned lowMemorySpace,
unsigned highMemorySpace,
int minDmaTransferSize = 1024);
/// Replaces affine_apply operations in CFGFunctions with the arithmetic
/// primitives (addition, multplication) they comprise. Errors out on
/// MLFunctions since they may contain affine_applies baked into the For loop
/// bounds that cannot be replaced.
FunctionPass *createDeaffinatorPass();
} // end namespace mlir
#endif // MLIR_TRANSFORMS_PASSES_H

View File

@ -0,0 +1,78 @@
//===- Deaffinator.cpp - Convert affine_apply to primitives -----*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines an MLIR function pass that replaces affine_apply operations
// in CFGFunctions with sequences of corresponding elementary arithmetic
// operations.
//
//===----------------------------------------------------------------------===//
//
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoweringUtils.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
namespace {
struct Deaffinator : public FunctionPass {
explicit Deaffinator() : FunctionPass(&Deaffinator::passID) {}
PassResult runOnMLFunction(MLFunction *f) override;
PassResult runOnCFGFunction(CFGFunction *f) override;
static char passID;
};
} // end anonymous namespace
char Deaffinator::passID = 0;
PassResult Deaffinator::runOnMLFunction(MLFunction *f) {
f->emitError("ML Functions contain syntactically hidden affine_apply's that "
"cannot be expanded");
return failure();
}
PassResult Deaffinator::runOnCFGFunction(CFGFunction *f) {
for (BasicBlock &bb : *f) {
// Handle iterators with care because we erase in the same loop.
// In particular, step to the next element before erasing the current one.
for (auto it = bb.begin(); it != bb.end();) {
Instruction &inst = *it;
auto affineApplyOp = inst.dyn_cast<AffineApplyOp>();
++it;
if (!affineApplyOp)
continue;
if (expandAffineApply(*affineApplyOp))
return failure();
}
}
return success();
}
static PassRegistration<Deaffinator>
pass("deaffinator", "Decompose affine_applies into primitive operations");
FunctionPass *createDeaffinatorPass() { return new Deaffinator(); }

View File

@ -0,0 +1,136 @@
//===- LoweringUtils.cpp - Utilities for Lowering Passes ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements utility functions for lowering passes, for example
// lowering affine_apply operations to individual components.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/LoweringUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
namespace {
// Visit affine expressions recursively and build the sequence of instructions
// that correspond to it. Visitation functions return an SSAValue of the
// expression subtree they visited or `nullptr` on error.
class AffineApplyExpander
: public AffineExprVisitor<AffineApplyExpander, SSAValue *> {
public:
// This must take AffineApplyOp by non-const reference because it needs
// non-const SSAValue pointers for arguments; it is not supposed to actually
// modify the op. Non-const SSAValues are required by the BinaryOp builders.
AffineApplyExpander(FuncBuilder &builder, AffineApplyOp &op)
: builder(builder), applyOp(op), loc(op.getLoc()) {}
template <typename OpTy> SSAValue *buildBinaryExpr(AffineBinaryOpExpr expr) {
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return nullptr;
auto op = builder.create<OpTy>(loc, lhs, rhs);
return op->getResult();
}
SSAValue *visitAddExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<AddIOp>(expr);
}
SSAValue *visitMulExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<MulIOp>(expr);
}
// TODO(zinenko): implement when the standard operators are made available.
SSAValue *visitModExpr(AffineBinaryOpExpr) {
builder.getContext()->emitError(loc, "unsupported binary operator: mod");
return nullptr;
}
SSAValue *visitFloorDivExpr(AffineBinaryOpExpr) {
builder.getContext()->emitError(loc,
"unsupported binary operator: floor_div");
return nullptr;
}
SSAValue *visitCeilDivExpr(AffineBinaryOpExpr) {
builder.getContext()->emitError(loc,
"unsupported binary operator: ceil_div");
return nullptr;
}
SSAValue *visitConstantExpr(AffineConstantExpr expr) {
auto valueAttr =
builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
auto op =
builder.create<ConstantOp>(loc, valueAttr, builder.getIndexType());
return op->getResult();
}
SSAValue *visitDimExpr(AffineDimExpr expr) {
assert(expr.getPosition() < applyOp.getNumOperands() &&
"affine dim position out of range");
// FIXME: this assumes a certain order of AffineApplyOp operands, the
// cleaner interface would be to separate them at the op level.
return applyOp.getOperand(expr.getPosition());
}
SSAValue *visitSymbolExpr(AffineSymbolExpr expr) {
// FIXME: this assumes a certain order of AffineApplyOp operands, the
// cleaner interface would be to separate them at the op level.
assert(expr.getPosition() + applyOp.getAffineMap().getNumDims() <
applyOp.getNumOperands() &&
"symbol dim position out of range");
return applyOp.getOperand(expr.getPosition() +
applyOp.getAffineMap().getNumDims());
}
private:
FuncBuilder &builder;
AffineApplyOp &applyOp;
Location loc;
};
} // namespace
// Given an affine expression `expr` extracted from `op`, build the sequence of
// primitive instructions that correspond to the affine expression in the
// `builder`.
static SSAValue *expandAffineExpr(FuncBuilder &builder, const AffineExpr &expr,
AffineApplyOp &op) {
auto expander = AffineApplyExpander(builder, op);
return expander.visit(expr);
}
bool mlir::expandAffineApply(AffineApplyOp &op) {
FuncBuilder builder(op.getOperation());
builder.setInsertionPoint(op.getOperation());
auto affineMap = op.getAffineMap();
for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) {
SSAValue *expanded = expandAffineExpr(builder, numberedExpr.value(), op);
if (!expanded)
return true;
op.getResult(numberedExpr.index())->replaceAllUsesWith(expanded);
}
op.erase();
return false;
}

View File

@ -0,0 +1,87 @@
// RUN: mlir-opt -deaffinator %s | FileCheck %s
#map0 = () -> (0)
#map1 = ()[s0] -> (s0)
#map2 = (d0) -> (d0)
#map3 = (d0)[s0] -> (d0 + s0 + 1)
#map4 = (d0,d1,d2,d3)[s0,s1,s2] -> (d0 + 2*d1 + 3*d2 + 4*d3 + 5*s0 + 6*s1 + 7*s2)
#map5 = (d0,d1,d2) -> (d0,d1,d2)
#map6 = (d0,d1,d2) -> (d0 + d1 + d2)
// CHECK-LABEL: cfgfunc @affine_applies()
cfgfunc @affine_applies() {
bb0:
// CHECK: %c0 = constant 0 : index
%zero = affine_apply #map0()
// Identity maps are just discarded.
// CHECK-NEXT: %c101 = constant 101 : index
%101 = constant 101 : index
%symbZero = affine_apply #map1()[%zero]
// CHECK-NEXT: %c102 = constant 102 : index
%102 = constant 102 : index
%copy = affine_apply #map2(%zero)
// CHECK-NEXT: %0 = addi %c0, %c0 : index
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %1 = addi %0, %c1 : index
%one = affine_apply #map3(%symbZero)[%zero]
// CHECK-NEXT: %c103 = constant 103 : index
// CHECK-NEXT: %c104 = constant 104 : index
// CHECK-NEXT: %c105 = constant 105 : index
// CHECK-NEXT: %c106 = constant 106 : index
// CHECK-NEXT: %c107 = constant 107 : index
// CHECK-NEXT: %c108 = constant 108 : index
// CHECK-NEXT: %c109 = constant 109 : index
%103 = constant 103 : index
%104 = constant 104 : index
%105 = constant 105 : index
%106 = constant 106 : index
%107 = constant 107 : index
%108 = constant 108 : index
%109 = constant 109 : index
// CHECK-NEXT: %c2 = constant 2 : index
// CHECK-NEXT: %2 = muli %c104, %c2 : index
// CHECK-NEXT: %3 = addi %c103, %2 : index
// CHECK-NEXT: %c3 = constant 3 : index
// CHECK-NEXT: %4 = muli %c105, %c3 : index
// CHECK-NEXT: %5 = addi %3, %4 : index
// CHECK-NEXT: %c4 = constant 4 : index
// CHECK-NEXT: %6 = muli %c106, %c4 : index
// CHECK-NEXT: %7 = addi %5, %6 : index
// CHECK-NEXT: %c5 = constant 5 : index
// CHECK-NEXT: %8 = muli %c107, %c5 : index
// CHECK-NEXT: %9 = addi %7, %8 : index
// CHECK-NEXT: %c6 = constant 6 : index
// CHECK-NEXT: %10 = muli %c108, %c6 : index
// CHECK-NEXT: %11 = addi %9, %10 : index
// CHECK-NEXT: %c7 = constant 7 : index
// CHECK-NEXT: %12 = muli %c109, %c7 : index
// CHECK-NEXT: %13 = addi %11, %12 : index
%four = affine_apply #map4(%103,%104,%105,%106)[%107,%108,%109]
return
}
// CHECK-LABEL: cfgfunc @multiresult_affine_apply()
cfgfunc @multiresult_affine_apply() {
// CHECK: bb0
bb0:
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: %0 = addi %c1, %c1 : index
// CHECK-NEXT: %1 = addi %0, %c1 : index
%one = constant 1 : index
%tuple = affine_apply #map5 (%one, %one, %one)
%three = affine_apply #map6 (%tuple#0, %tuple#1, %tuple#2)
return
}
// CHECK-LABEL: cfgfunc @args_ret_affine_apply
cfgfunc @args_ret_affine_apply(index, index) -> (index, index) {
// CHECK: bb0(%arg0: index, %arg1: index):
bb0(%0 : index, %1 : index):
// CHECK-NEXT: return %arg0, %arg1 : index, index
%00 = affine_apply #map2 (%0)
%11 = affine_apply #map1 ()[%1]
return %00, %11 : index, index
}