NFC: Move AffineApplyOp to the AffineOps dialect. This also moves the isValidDim/isValidSymbol methods from Value to the AffineOps dialect.

PiperOrigin-RevId: 232386632
This commit is contained in:
River Riddle 2019-02-04 16:15:13 -08:00 committed by jpienaar
parent 0f50414fa4
commit c9ad4621ce
13 changed files with 310 additions and 316 deletions

View File

@ -35,6 +35,55 @@ public:
AffineOpsDialect(MLIRContext *context);
};
/// The "affine_apply" operation applies an affine map to a list of operands,
/// yielding a single result. The operand list must be the same size as the
/// number of arguments to the affine mapping. All operands and the result are
/// of type 'Index'. This operation requires a single affine map attribute named
/// "map". For example:
///
/// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
/// (index) -> (index)
///
/// equivalently:
///
/// #map42 = (d0)->(d0+1)
/// %y = affine_apply #map42(%x)
///
class AffineApplyOp : public Op<AffineApplyOp, OpTrait::VariadicOperands,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
/// Builds an affine apply op with the specified map and operands.
static void build(Builder *builder, OperationState *result, AffineMap map,
ArrayRef<Value *> operands);
/// Returns the affine map to be applied by this operation.
AffineMap getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map").getValue();
}
/// Returns true if the result of this operation can be used as dimension id.
bool isValidDim() const;
/// Returns true if the result of this operation is a symbol.
bool isValidSymbol() const;
static StringRef getOperationName() { return "affine_apply"; }
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
private:
friend class Instruction;
explicit AffineApplyOp(const Instruction *state) : Op(state) {}
};
/// The "for" instruction represents an affine loop nest, defining an SSA value
/// for its induction variable. The induction variable is represented as a
/// BlockArgument to the entry block of the body. The body and induction
@ -301,6 +350,18 @@ private:
explicit AffineIfOp(const Instruction *state) : Op(state) {}
};
/// Returns true if the given Value can be used as a dimension id.
bool isValidDim(const Value *value);
/// Returns true if the given Value can be used as a symbol.
bool isValidSymbol(const Value *value);
/// Modifies both `map` and `operands` in-place so as to:
/// 1. drop duplicate operands
/// 2. drop unused dims and symbols from map
void canonicalizeMapAndOperands(AffineMap *map,
llvm::SmallVectorImpl<Value *> *operands);
} // end namespace mlir
#endif

View File

@ -36,55 +36,6 @@ public:
BuiltinDialect(MLIRContext *context);
};
/// The "affine_apply" operation applies an affine map to a list of operands,
/// yielding a single result. The operand list must be the same size as the
/// number of arguments to the affine mapping. All operands and the result are
/// of type 'Index'. This operation requires a single affine map attribute named
/// "map". For example:
///
/// %y = "affine_apply" (%x) { map: (d0) -> (d0 + 1) } :
/// (index) -> (index)
///
/// equivalently:
///
/// #map42 = (d0)->(d0+1)
/// %y = affine_apply #map42(%x)
///
class AffineApplyOp : public Op<AffineApplyOp, OpTrait::VariadicOperands,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
/// Builds an affine apply op with the specified map and operands.
static void build(Builder *builder, OperationState *result, AffineMap map,
ArrayRef<Value *> operands);
/// Returns the affine map to be applied by this operation.
AffineMap getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map").getValue();
}
/// Returns true if the result of this operation can be used as dimension id.
bool isValidDim() const;
/// Returns true if the result of this operation is a symbol.
bool isValidSymbol() const;
static StringRef getOperationName() { return "affine_apply"; }
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
private:
friend class Instruction;
explicit AffineApplyOp(const Instruction *state) : Op(state) {}
};
/// The "br" operation represents a branch instruction in a CFG function.
/// The operation takes variable number of operands and produces no results.
/// The operand number and types for each successor must match the
@ -397,12 +348,6 @@ bool parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims);
/// Modifies both `map` and `operands` in-place so as to:
/// 1. drop duplicate operands
/// 2. drop unused dims and symbols from map
void canonicalizeMapAndOperands(AffineMap *map,
llvm::SmallVectorImpl<Value *> *operands);
} // end namespace mlir
#endif

View File

@ -59,15 +59,6 @@ public:
IRObjectWithUseList::replaceAllUsesWith(newValue);
}
/// TODO: move isValidDim/isValidSymbol to a utility library specific to the
/// polyhedral operations.
/// Returns true if the given Value can be used as a dimension id.
bool isValidDim() const;
/// Returns true if the given Value can be used as a symbol.
bool isValidSymbol() const;
/// Return the function that this Value is defined in.
Function *getFunction();

View File

@ -32,6 +32,7 @@
namespace mlir {
class AffineApplyOp;
class AffineForOp;
class FuncBuilder;
class Location;

View File

@ -22,6 +22,8 @@
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@ -30,7 +32,241 @@ using namespace mlir;
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
addOperations<AffineForOp, AffineIfOp>();
addOperations<AffineApplyOp, AffineForOp, AffineIfOp>();
}
// Value can be used as a dimension id if it is valid as a symbol, or
// it is an induction variable, or it is a result of affine apply operation
// with dimension id arguments.
bool mlir::isValidDim(const Value *value) {
if (auto *inst = value->getDefiningInst()) {
// Top level instruction or constant operation is ok.
if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
return true;
// Affine apply operation is ok if all of its operands are ok.
if (auto op = inst->dyn_cast<AffineApplyOp>())
return op->isValidDim();
return false;
}
// This value is a block argument.
return true;
}
// Value can be used as a symbol if it is a constant, or it is defined at
// the top level, or it is a result of affine apply operation with symbol
// arguments.
bool mlir::isValidSymbol(const Value *value) {
if (auto *inst = value->getDefiningInst()) {
// Top level instruction or constant operation is ok.
if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
return true;
// Affine apply operation is ok if all of its operands are ok.
if (auto op = inst->dyn_cast<AffineApplyOp>())
return op->isValidSymbol();
return false;
}
// Otherwise, the only valid symbol is a non induction variable block
// argument.
return !isForInductionVar(value);
}
//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//
void AffineApplyOp::build(Builder *builder, OperationState *result,
AffineMap map, ArrayRef<Value *> operands) {
result->addOperands(operands);
result->types.append(map.getNumResults(), builder->getIndexType());
result->addAttribute("map", builder->getAffineMapAttr(map));
}
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
auto affineIntTy = builder.getIndexType();
AffineMapAttr mapAttr;
unsigned numDims;
if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
parseDimAndSymbolList(parser, result->operands, numDims) ||
parser->parseOptionalAttributeDict(result->attributes))
return true;
auto map = mapAttr.getValue();
if (map.getNumDims() != numDims ||
numDims + map.getNumSymbols() != result->operands.size()) {
return parser->emitError(parser->getNameLoc(),
"dimension or symbol index mismatch");
}
result->types.append(map.getNumResults(), affineIntTy);
return false;
}
void AffineApplyOp::print(OpAsmPrinter *p) const {
auto map = getAffineMap();
*p << "affine_apply " << map;
printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p);
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
}
bool AffineApplyOp::verify() const {
// Check that affine map attribute was specified.
auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
if (!affineMapAttr)
return emitOpError("requires an affine map");
// Check input and output dimensions match.
auto map = affineMapAttr.getValue();
// Verify that operand count matches affine map dimension and symbol count.
if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
return emitOpError(
"operand count and affine map dimension and symbol count must match");
// Verify that result count matches affine map result count.
if (map.getNumResults() != 1)
return emitOpError("mapping must produce one value");
return false;
}
// The result of the affine apply operation can be used as a dimension id if it
// is a CFG value or if it is an Value, and all the operands are valid
// dimension ids.
bool AffineApplyOp::isValidDim() const {
return llvm::all_of(getOperands(),
[](const Value *op) { return mlir::isValidDim(op); });
}
// The result of the affine apply operation can be used as a symbol if it is
// a CFG value or if it is an Value, and all the operands are symbols.
bool AffineApplyOp::isValidSymbol() const {
return llvm::all_of(getOperands(),
[](const Value *op) { return mlir::isValidSymbol(op); });
}
Attribute AffineApplyOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
auto map = getAffineMap();
SmallVector<Attribute, 1> result;
if (map.constantFold(operands, result))
return Attribute();
return result[0];
}
namespace {
/// SimplifyAffineApply operations.
///
struct SimplifyAffineApply : public RewritePattern {
SimplifyAffineApply(MLIRContext *context)
: RewritePattern(AffineApplyOp::getOperationName(), 1, context) {}
PatternMatchResult match(Instruction *op) const override;
void rewrite(Instruction *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override;
};
} // end anonymous namespace.
namespace {
/// FIXME: this is massive overkill for simple obviously always matching
/// canonicalizations. Fix the pattern rewriter to make this easy.
struct SimplifyAffineApplyState : public PatternState {
AffineMap map;
SmallVector<Value *, 8> operands;
SimplifyAffineApplyState(AffineMap map,
const SmallVector<Value *, 8> &operands)
: map(map), operands(operands) {}
};
} // end anonymous namespace.
void mlir::canonicalizeMapAndOperands(
AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
if (!map || operands->empty())
return;
assert(map->getNumInputs() == operands->size() &&
"map inputs must match number of operands");
// Check to see what dims are used.
llvm::SmallBitVector usedDims(map->getNumDims());
llvm::SmallBitVector usedSyms(map->getNumSymbols());
map->walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
usedDims[dimExpr.getPosition()] = true;
else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
usedSyms[symExpr.getPosition()] = true;
});
auto *context = map->getContext();
SmallVector<Value *, 8> resultOperands;
resultOperands.reserve(operands->size());
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
unsigned nextDim = 0;
for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
if (usedDims[i]) {
auto it = seenDims.find((*operands)[i]);
if (it == seenDims.end()) {
dimRemapping[i] = getAffineDimExpr(nextDim++, context);
resultOperands.push_back((*operands)[i]);
seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
} else {
dimRemapping[i] = it->second;
}
}
}
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
unsigned nextSym = 0;
for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
if (usedSyms[i]) {
auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
if (it == seenSymbols.end()) {
symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
resultOperands.push_back((*operands)[i + map->getNumDims()]);
seenSymbols.insert(std::make_pair((*operands)[i + map->getNumDims()],
symRemapping[i]));
} else {
symRemapping[i] = it->second;
}
}
}
*map =
map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
*operands = resultOperands;
}
PatternMatchResult SimplifyAffineApply::match(Instruction *op) const {
auto apply = op->cast<AffineApplyOp>();
auto map = apply->getAffineMap();
AffineMap oldMap = map;
SmallVector<Value *, 8> resultOperands(apply->getOperands().begin(),
apply->getOperands().end());
canonicalizeMapAndOperands(&map, &resultOperands);
if (map != oldMap)
return matchSuccess(
std::make_unique<SimplifyAffineApplyState>(map, resultOperands));
return matchFailure();
}
void SimplifyAffineApply::rewrite(Instruction *op,
std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const {
auto *applyState = static_cast<SimplifyAffineApplyState *>(state.get());
rewriter.replaceOpWithNewOp<AffineApplyOp>(op, applyState->map,
applyState->operands);
}
void AffineApplyOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(std::make_unique<SimplifyAffineApply>(context));
}
//===----------------------------------------------------------------------===//
@ -493,9 +729,9 @@ bool AffineIfOp::verify() const {
IntegerSet condition = conditionAttr.getValue();
for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
const Value *operand = getOperand(i);
if (i < condition.getNumDims() && !operand->isValidDim())
if (i < condition.getNumDims() && !isValidDim(operand))
return emitOpError("operand cannot be used as a dimension id");
if (i >= condition.getNumDims() && !operand->isValidSymbol())
if (i >= condition.getNumDims() && !isValidSymbol(operand))
return emitOpError("operand cannot be used as a symbol");
}

View File

@ -681,7 +681,7 @@ static void buildDimAndSymbolPositionMaps(
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto *value = values[i];
if (!isForInductionVar(values[i])) {
assert(values[i]->isValidSymbol() &&
assert(isValidSymbol(values[i]) &&
"access operand has to be either a loop IV or a symbol");
valuePosMap->addSymbolValue(value);
} else if (isSrc) {
@ -743,7 +743,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain,
auto setSymbolIds = [&](ArrayRef<Value *> values) {
for (auto *value : values) {
if (!isForInductionVar(value)) {
assert(value->isValidSymbol() && "expected symbol");
assert(isValidSymbol(value) && "expected symbol");
dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
}
}
@ -913,7 +913,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
if (isForInductionVar(operands[i]))
continue;
auto *symbol = operands[i];
assert(symbol->isValidSymbol());
assert(isValidSymbol(symbol));
// Check if the symbol is a constant.
if (auto *opInst = symbol->getDefiningInst()) {
if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {

View File

@ -1268,7 +1268,7 @@ bool FlatAffineConstraints::addAffineForOpDomain(
for (const auto &operand : operands) {
unsigned loc;
if (!findId(*operand, &loc)) {
if (operand->isValidSymbol()) {
if (isValidSymbol(operand)) {
addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
loc = getNumDimIds() + getNumSymbolIds() - 1;
// Check if the symbol is a constant.

View File

@ -168,7 +168,7 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth) {
} else {
// Has to be a valid symbol.
auto *symbol = accessValueMap.getOperand(i);
assert(symbol->isValidSymbol());
assert(isValidSymbol(symbol));
// Check if the symbol is a constant.
if (auto *inst = symbol->getDefiningInst()) {
if (auto constOp = inst->dyn_cast<ConstantIndexOp>()) {

View File

@ -27,7 +27,6 @@
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@ -38,7 +37,7 @@ using namespace mlir;
BuiltinDialect::BuiltinDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
addOperations<AffineApplyOp, BranchOp, CondBranchOp, ConstantOp, ReturnOp>();
addOperations<BranchOp, CondBranchOp, ConstantOp, ReturnOp>();
addTypes<FunctionType, IndexType, UnknownType, FloatType, IntegerType,
VectorType, RankedTensorType, UnrankedTensorType, MemRefType>();
}
@ -78,211 +77,6 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
return false;
}
//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//
void AffineApplyOp::build(Builder *builder, OperationState *result,
AffineMap map, ArrayRef<Value *> operands) {
result->addOperands(operands);
result->types.append(map.getNumResults(), builder->getIndexType());
result->addAttribute("map", builder->getAffineMapAttr(map));
}
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
auto affineIntTy = builder.getIndexType();
AffineMapAttr mapAttr;
unsigned numDims;
if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
parseDimAndSymbolList(parser, result->operands, numDims) ||
parser->parseOptionalAttributeDict(result->attributes))
return true;
auto map = mapAttr.getValue();
if (map.getNumDims() != numDims ||
numDims + map.getNumSymbols() != result->operands.size()) {
return parser->emitError(parser->getNameLoc(),
"dimension or symbol index mismatch");
}
result->types.append(map.getNumResults(), affineIntTy);
return false;
}
void AffineApplyOp::print(OpAsmPrinter *p) const {
auto map = getAffineMap();
*p << "affine_apply " << map;
printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p);
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
}
bool AffineApplyOp::verify() const {
// Check that affine map attribute was specified.
auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
if (!affineMapAttr)
return emitOpError("requires an affine map");
// Check input and output dimensions match.
auto map = affineMapAttr.getValue();
// Verify that operand count matches affine map dimension and symbol count.
if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
return emitOpError(
"operand count and affine map dimension and symbol count must match");
// Verify that result count matches affine map result count.
if (map.getNumResults() != 1)
return emitOpError("mapping must produce one value");
return false;
}
// The result of the affine apply operation can be used as a dimension id if it
// is a CFG value or if it is an Value, and all the operands are valid
// dimension ids.
bool AffineApplyOp::isValidDim() const {
for (auto *op : getOperands()) {
if (!op->isValidDim())
return false;
}
return true;
}
// The result of the affine apply operation can be used as a symbol if it is
// a CFG value or if it is an Value, and all the operands are symbols.
bool AffineApplyOp::isValidSymbol() const {
for (auto *op : getOperands()) {
if (!op->isValidSymbol())
return false;
}
return true;
}
Attribute AffineApplyOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
auto map = getAffineMap();
SmallVector<Attribute, 1> result;
if (map.constantFold(operands, result))
return Attribute();
return result[0];
}
namespace {
/// SimplifyAffineApply operations.
///
struct SimplifyAffineApply : public RewritePattern {
SimplifyAffineApply(MLIRContext *context)
: RewritePattern(AffineApplyOp::getOperationName(), 1, context) {}
PatternMatchResult match(Instruction *op) const override;
void rewrite(Instruction *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override;
};
} // end anonymous namespace.
namespace {
/// FIXME: this is massive overkill for simple obviously always matching
/// canonicalizations. Fix the pattern rewriter to make this easy.
struct SimplifyAffineApplyState : public PatternState {
AffineMap map;
SmallVector<Value *, 8> operands;
SimplifyAffineApplyState(AffineMap map,
const SmallVector<Value *, 8> &operands)
: map(map), operands(operands) {}
};
} // end anonymous namespace.
void mlir::canonicalizeMapAndOperands(
AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
if (!map || operands->empty())
return;
assert(map->getNumInputs() == operands->size() &&
"map inputs must match number of operands");
// Check to see what dims are used.
llvm::SmallBitVector usedDims(map->getNumDims());
llvm::SmallBitVector usedSyms(map->getNumSymbols());
map->walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
usedDims[dimExpr.getPosition()] = true;
else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
usedSyms[symExpr.getPosition()] = true;
});
auto *context = map->getContext();
SmallVector<Value *, 8> resultOperands;
resultOperands.reserve(operands->size());
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
unsigned nextDim = 0;
for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
if (usedDims[i]) {
auto it = seenDims.find((*operands)[i]);
if (it == seenDims.end()) {
dimRemapping[i] = getAffineDimExpr(nextDim++, context);
resultOperands.push_back((*operands)[i]);
seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
} else {
dimRemapping[i] = it->second;
}
}
}
llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
unsigned nextSym = 0;
for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
if (usedSyms[i]) {
auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
if (it == seenSymbols.end()) {
symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
resultOperands.push_back((*operands)[i + map->getNumDims()]);
seenSymbols.insert(std::make_pair((*operands)[i + map->getNumDims()],
symRemapping[i]));
} else {
symRemapping[i] = it->second;
}
}
}
*map =
map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
*operands = resultOperands;
}
PatternMatchResult SimplifyAffineApply::match(Instruction *op) const {
auto apply = op->cast<AffineApplyOp>();
auto map = apply->getAffineMap();
AffineMap oldMap = map;
SmallVector<Value *, 8> resultOperands(apply->getOperands().begin(),
apply->getOperands().end());
canonicalizeMapAndOperands(&map, &resultOperands);
if (map != oldMap)
return matchSuccess(
std::make_unique<SimplifyAffineApplyState>(map, resultOperands));
return matchFailure();
}
void SimplifyAffineApply::rewrite(Instruction *op,
std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const {
auto *applyState = static_cast<SimplifyAffineApplyState *>(state.get());
rewriter.replaceOpWithNewOp<AffineApplyOp>(op, applyState->map,
applyState->operands);
}
void AffineApplyOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(std::make_unique<SimplifyAffineApply>(context));
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//

View File

@ -300,42 +300,6 @@ Function *Instruction::getFunction() const {
return block ? block->getFunction() : nullptr;
}
// Value can be used as a dimension id if it is valid as a symbol, or
// it is an induction variable, or it is a result of affine apply operation
// with dimension id arguments.
bool Value::isValidDim() const {
if (auto *inst = getDefiningInst()) {
// Top level instruction or constant operation is ok.
if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
return true;
// Affine apply operation is ok if all of its operands are ok.
if (auto op = inst->dyn_cast<AffineApplyOp>())
return op->isValidDim();
return false;
}
// This value is either a function argument or an induction variable. Both
// are ok.
return true;
}
// Value can be used as a symbol if it is a constant, or it is defined at
// the top level, or it is a result of affine apply operation with symbol
// arguments.
bool Value::isValidSymbol() const {
if (auto *inst = getDefiningInst()) {
// Top level instruction or constant operation is ok.
if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
return true;
// Affine apply operation is ok if all of its operands are ok.
if (auto op = inst->dyn_cast<AffineApplyOp>())
return op->isValidSymbol();
return false;
}
// Otherwise, the only valid symbol is a function argument.
auto *arg = dyn_cast<BlockArgument>(this);
return arg && arg->isFunctionArgument();
}
/// Emit a note about this instruction, reporting up to any diagnostic
/// handlers that may be listening.
void Instruction::emitNote(const Twine &message) const {

View File

@ -21,6 +21,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"

View File

@ -124,7 +124,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
for (auto *extraIndex : extraIndices) {
assert(extraIndex->getDefiningInst()->getNumResults() == 1 &&
"single result op's expected to generate these indices");
assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) &&
assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
"invalid memory op index");
state.operands.push_back(extraIndex);
}

View File

@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/SliceAnalysis.h"