Recommit: Define a AffineOps dialect as well as an AffineIfOp operation. Replace all instances of IfInst with AffineIfOp and delete IfInst.

PiperOrigin-RevId: 231342063
This commit is contained in:
River Riddle 2019-01-28 21:23:53 -08:00 committed by jpienaar
parent 39d81f246a
commit 755538328b
36 changed files with 495 additions and 542 deletions

View File

@ -0,0 +1,91 @@
//===- AffineOps.h - MLIR Affine Operations -------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines convenience types for working with Affine operations
// in the MLIR instruction set.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_AFFINEOPS_AFFINEOPS_H
#define MLIR_AFFINEOPS_AFFINEOPS_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class AffineOpsDialect : public Dialect {
public:
AffineOpsDialect(MLIRContext *context);
};
/// The "if" operation represents an ifthenelse construct for conditionally
/// executing two regions of code. The operands to an if operation are an
/// IntegerSet condition and a set of symbol/dimension operands to the
/// condition set. The operation produces no results. For example:
///
/// if #set(%i) {
/// ...
/// } else {
/// ...
/// }
///
/// The 'else' blocks to the if operation are optional, and may be omitted. For
/// example:
///
/// if #set(%i) {
/// ...
/// }
///
class AffineIfOp
: public Op<AffineIfOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
// Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result,
IntegerSet condition, ArrayRef<Value *> conditionOperands);
static StringRef getOperationName() { return "if"; }
static StringRef getConditionAttrName() { return "condition"; }
IntegerSet getIntegerSet() const;
void setIntegerSet(IntegerSet newSet);
/// Returns the list of 'then' blocks.
BlockList &getThenBlocks();
const BlockList &getThenBlocks() const {
return const_cast<AffineIfOp *>(this)->getThenBlocks();
}
/// Returns the list of 'else' blocks.
BlockList &getElseBlocks();
const BlockList &getElseBlocks() const {
return const_cast<AffineIfOp *>(this)->getElseBlocks();
}
bool verify() const;
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
private:
friend class OperationInst;
explicit AffineIfOp(const OperationInst *state) : Op(state) {}
};
} // end namespace mlir
#endif

View File

@ -128,7 +128,6 @@ private:
void matchOne(Instruction *elem);
void visitForInst(ForInst *forInst) { matchOne(forInst); }
void visitIfInst(IfInst *ifInst) { matchOne(ifInst); }
void visitOperationInst(OperationInst *opInst) { matchOne(opInst); }
/// POD paylod.

View File

@ -26,7 +26,6 @@
#include "llvm/ADT/PointerUnion.h"
namespace mlir {
class IfInst;
class BlockList;
class BlockAndValueMapping;
@ -62,7 +61,7 @@ public:
}
/// Returns the function that this block is part of, even if the block is
/// nested under an IfInst or ForInst.
/// nested under an OperationInst or ForInst.
Function *getFunction();
const Function *getFunction() const {
return const_cast<Block *>(this)->getFunction();
@ -325,7 +324,7 @@ private:
namespace mlir {
/// This class contains a list of basic blocks and has a notion of the object it
/// is part of - a Function or IfInst or ForInst.
/// is part of - a Function or OperationInst or ForInst.
class BlockList {
public:
explicit BlockList(Function *container);
@ -365,15 +364,16 @@ public:
return &BlockList::blocks;
}
/// A BlockList is part of a Function or and IfInst/ForInst. If it is
/// part of an IfInst/ForInst, then return it, otherwise return null.
/// A BlockList is part of a function or an operation region. If it is
/// part of an operation region, then return the operation, otherwise return
/// null.
Instruction *getContainingInst();
const Instruction *getContainingInst() const {
return const_cast<BlockList *>(this)->getContainingInst();
}
/// A BlockList is part of a Function or and IfInst/ForInst. If it is
/// part of a Function, then return it, otherwise return null.
/// A BlockList is part of a function or an operation region. If it is part
/// of a Function, then return it, otherwise return null.
Function *getContainingFunction();
const Function *getContainingFunction() const {
return const_cast<BlockList *>(this)->getContainingFunction();

View File

@ -286,10 +286,6 @@ public:
// Default step is 1.
ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
/// Creates if instruction.
IfInst *createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set);
private:
Function *function;
Block *block = nullptr;

View File

@ -44,7 +44,7 @@
// lc.walk(function);
// numLoops = lc.numLoops;
//
// There are 'visit' methods for OperationInst, ForInst, IfInst, and
// There are 'visit' methods for OperationInst, ForInst, and
// Function, which recursively process all contained instructions.
//
// Note that if you don't implement visitXXX for some instruction type,
@ -85,8 +85,6 @@ public:
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->visitForInst(cast<ForInst>(s));
case Instruction::Kind::If:
return static_cast<SubClass *>(this)->visitIfInst(cast<IfInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->visitOperationInst(
cast<OperationInst>(s));
@ -104,7 +102,6 @@ public:
// When visiting a for inst, if inst, or an operation inst directly, these
// methods get called to indicate when transitioning into a new unit.
void visitForInst(ForInst *forInst) {}
void visitIfInst(IfInst *ifInst) {}
void visitOperationInst(OperationInst *opInst) {}
};
@ -166,23 +163,6 @@ public:
static_cast<SubClass *>(this)->visitForInst(forInst);
}
void walkIfInst(IfInst *ifInst) {
static_cast<SubClass *>(this)->visitIfInst(ifInst);
static_cast<SubClass *>(this)->walk(ifInst->getThen()->begin(),
ifInst->getThen()->end());
if (auto *elseBlock = ifInst->getElse())
static_cast<SubClass *>(this)->walk(elseBlock->begin(), elseBlock->end());
}
void walkIfInstPostOrder(IfInst *ifInst) {
static_cast<SubClass *>(this)->walkPostOrder(ifInst->getThen()->begin(),
ifInst->getThen()->end());
if (auto *elseBlock = ifInst->getElse())
static_cast<SubClass *>(this)->walkPostOrder(elseBlock->begin(),
elseBlock->end());
static_cast<SubClass *>(this)->visitIfInst(ifInst);
}
// Function to walk a instruction.
RetTy walk(Instruction *s) {
static_assert(std::is_base_of<InstWalker, SubClass>::value,
@ -193,8 +173,6 @@ public:
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->walkForInst(cast<ForInst>(s));
case Instruction::Kind::If:
return static_cast<SubClass *>(this)->walkIfInst(cast<IfInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s));
}
@ -210,9 +188,6 @@ public:
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->walkForInstPostOrder(
cast<ForInst>(s));
case Instruction::Kind::If:
return static_cast<SubClass *>(this)->walkIfInstPostOrder(
cast<IfInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpInstPostOrder(
cast<OperationInst>(s));
@ -231,7 +206,6 @@ public:
// processing their descendants in some way. When using RetTy, all of these
// need to be overridden.
void visitForInst(ForInst *forInst) {}
void visitIfInst(IfInst *ifInst) {}
void visitOperationInst(OperationInst *opInst) {}
void visitInstruction(Instruction *inst) {}
};

View File

@ -75,7 +75,6 @@ public:
enum class Kind {
OperationInst = (int)IROperandOwner::Kind::OperationInst,
For = (int)IROperandOwner::Kind::ForInst,
If = (int)IROperandOwner::Kind::IfInst,
};
Kind getKind() const { return (Kind)IROperandOwner::getKind(); }

View File

@ -794,130 +794,6 @@ private:
friend class ForInst;
};
/// If instruction restricts execution to a subset of the loop iteration space.
class IfInst : public Instruction {
public:
static IfInst *create(Location location, ArrayRef<Value *> operands,
IntegerSet set);
~IfInst();
//===--------------------------------------------------------------------===//
// Then, else, condition.
//===--------------------------------------------------------------------===//
Block *getThen() { return &thenClause.front(); }
const Block *getThen() const { return &thenClause.front(); }
Block *getElse() { return elseClause ? &elseClause->front() : nullptr; }
const Block *getElse() const {
return elseClause ? &elseClause->front() : nullptr;
}
bool hasElse() const { return elseClause != nullptr; }
Block *createElse() {
assert(elseClause == nullptr && "already has an else clause!");
elseClause = new BlockList(this);
elseClause->push_back(new Block());
return &elseClause->front();
}
const AffineCondition getCondition() const;
IntegerSet getIntegerSet() const { return set; }
void setIntegerSet(IntegerSet newSet) {
assert(newSet.getNumOperands() == operands.size());
set = newSet;
}
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
/// Operand iterators.
using operand_iterator = OperandIterator<IfInst, Value>;
using const_operand_iterator = OperandIterator<const IfInst, const Value>;
/// Operand iterator range.
using operand_range = llvm::iterator_range<operand_iterator>;
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
unsigned getNumOperands() const { return operands.size(); }
Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
const Value *getOperand(unsigned idx) const {
return getInstOperand(idx).get();
}
void setOperand(unsigned idx, Value *value) {
getInstOperand(idx).set(value);
}
operand_iterator operand_begin() { return operand_iterator(this, 0); }
operand_iterator operand_end() {
return operand_iterator(this, getNumOperands());
}
const_operand_iterator operand_begin() const {
return const_operand_iterator(this, 0);
}
const_operand_iterator operand_end() const {
return const_operand_iterator(this, getNumOperands());
}
ArrayRef<InstOperand> getInstOperands() const { return operands; }
MutableArrayRef<InstOperand> getInstOperands() { return operands; }
InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
const InstOperand &getInstOperand(unsigned idx) const {
return getInstOperands()[idx];
}
//===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//
MLIRContext *getContext() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const IROperandOwner *ptr) {
return ptr->getKind() == IROperandOwner::Kind::IfInst;
}
private:
// it is always present.
BlockList thenClause;
// 'else' clause of the if instruction. 'nullptr' if there is no else clause.
BlockList *elseClause;
// The integer set capturing the conditional guard.
IntegerSet set;
// Condition operands.
std::vector<InstOperand> operands;
explicit IfInst(Location location, unsigned numOperands, IntegerSet set);
};
/// AffineCondition represents a condition of the 'if' instruction.
/// Its life span should not exceed that of the objects it refers to.
/// AffineCondition does not provide its own methods for iterating over
/// the operands since the iterators of the if instruction accomplish
/// the same purpose.
///
/// AffineCondition is trivially copyable, so it should be passed by value.
class AffineCondition {
public:
const IfInst *getIfInst() const { return &inst; }
IntegerSet getIntegerSet() const { return set; }
private:
// 'if' instruction that contains this affine condition.
const IfInst &inst;
// Integer set for this affine condition.
IntegerSet set;
AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {}
friend class IfInst;
};
} // end namespace mlir
#endif // MLIR_IR_INSTRUCTIONS_H

View File

@ -89,6 +89,9 @@ public:
/// Print the entire operation with the default generic assembly form.
virtual void printGenericOp(const OperationInst *op) = 0;
/// Prints a block list.
virtual void printBlockList(const BlockList &blocks) = 0;
private:
OpAsmPrinter(const OpAsmPrinter &) = delete;
void operator=(const OpAsmPrinter &) = delete;
@ -195,7 +198,19 @@ public:
virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a keyword followed by a type.
virtual bool parseKeywordType(const char *keyword, Type &result) = 0;
bool parseKeywordType(const char *keyword, Type &result) {
return parseKeyword(keyword) || parseType(result);
}
/// Parse a keyword.
bool parseKeyword(const char *keyword) {
if (parseOptionalKeyword(keyword))
return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'");
return false;
}
/// If a keyword is present, then parse it.
virtual bool parseOptionalKeyword(const char *keyword) = 0;
/// Add the specified type to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and
@ -296,6 +311,10 @@ public:
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
/// Parses a block list. Any parsed blocks are filled in to the
/// operation's block lists after the operation is created.
virtual bool parseBlockList() = 0;
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//

View File

@ -81,10 +81,9 @@ public:
enum class Kind {
OperationInst,
ForInst,
IfInst,
/// These enums define ranges used for classof implementations.
INST_LAST = IfInst,
INST_LAST = ForInst,
};
Kind getKind() const { return locationAndKind.getInt(); }

View File

@ -93,7 +93,7 @@ using OwningMLLoweringPatternList =
/// next _original_ operation is considered.
/// In other words, for each operation, the pass applies the first matching
/// rewriter in the list and advances to the (lexically) next operation.
/// Non-operation instructions (ForInst and IfInst) are ignored.
/// Non-operation instructions (ForInst) are ignored.
/// This is similar to greedy worklist-based pattern rewriter, except that this
/// operates on ML functions using an ML builder and does not maintain the work
/// list. Note that, as of the time of writing, worklist-based rewriter did not

View File

@ -0,0 +1,151 @@
//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// AffineOpsDialect
//===----------------------------------------------------------------------===//
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
addOperations<AffineIfOp>();
}
//===----------------------------------------------------------------------===//
// AffineIfOp
//===----------------------------------------------------------------------===//
void AffineIfOp::build(Builder *builder, OperationState *result,
IntegerSet condition,
ArrayRef<Value *> conditionOperands) {
result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition));
result->addOperands(conditionOperands);
// Reserve 2 block lists, one for the 'then' and one for the 'else' regions.
result->reserveBlockLists(2);
}
bool AffineIfOp::verify() const {
// Verify that we have a condition attribute.
auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName());
if (!conditionAttr)
return emitOpError("requires an integer set attribute named 'condition'");
// Verify that the operands are valid dimension/symbols.
IntegerSet condition = conditionAttr.getValue();
for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
const Value *operand = getOperand(i);
if (i < condition.getNumDims() && !operand->isValidDim())
return emitOpError("operand cannot be used as a dimension id");
if (i >= condition.getNumDims() && !operand->isValidSymbol())
return emitOpError("operand cannot be used as a symbol");
}
// Verify that the entry of each child blocklist does not have arguments.
for (const auto &blockList : getInstruction()->getBlockLists()) {
if (blockList.empty())
continue;
// TODO(riverriddle) We currently do not allow multiple blocks in child
// block lists.
if (std::next(blockList.begin()) != blockList.end())
return emitOpError(
"expects only one block per 'if' or 'else' block list");
if (blockList.front().getTerminator())
return emitOpError("expects region block to not have a terminator");
for (const auto &b : blockList)
if (b.getNumArguments() != 0)
return emitOpError(
"requires that child entry blocks have no arguments");
}
return false;
}
bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) {
// Parse the condition attribute set.
IntegerSetAttr conditionAttr;
unsigned numDims;
if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(),
result->attributes) ||
parseDimAndSymbolList(parser, result->operands, numDims))
return true;
// Verify the condition operands.
auto set = conditionAttr.getValue();
if (set.getNumDims() != numDims)
return parser->emitError(
parser->getNameLoc(),
"dim operand count and integer set dim count must match");
if (numDims + set.getNumSymbols() != result->operands.size())
return parser->emitError(
parser->getNameLoc(),
"symbol operand count and integer set symbol count must match");
// Parse the 'then' block list.
if (parser->parseBlockList())
return true;
// If we find an 'else' keyword then parse the else block list.
if (!parser->parseOptionalKeyword("else")) {
if (parser->parseBlockList())
return true;
}
// Reserve 2 block lists, one for the 'then' and one for the 'else' regions.
result->reserveBlockLists(2);
return false;
}
void AffineIfOp::print(OpAsmPrinter *p) const {
auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName());
*p << "if " << conditionAttr;
printDimAndSymbolList(operand_begin(), operand_end(),
conditionAttr.getValue().getNumDims(), p);
p->printBlockList(getInstruction()->getBlockList(0));
// Print the 'else' block list if it has any blocks.
const auto &elseBlockList = getInstruction()->getBlockList(1);
if (!elseBlockList.empty()) {
*p << " else";
p->printBlockList(elseBlockList);
}
}
IntegerSet AffineIfOp::getIntegerSet() const {
return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
}
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
setAttr(
Identifier::get(getConditionAttrName(), getInstruction()->getContext()),
IntegerSetAttr::get(newSet));
}
/// Returns the list of 'then' blocks.
BlockList &AffineIfOp::getThenBlocks() {
return getInstruction()->getBlockList(0);
}
/// Returns the list of 'else' blocks.
BlockList &AffineIfOp::getElseBlocks() {
return getInstruction()->getBlockList(1);
}

View File

@ -0,0 +1,22 @@
//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/AffineOps/AffineOps.h"
using namespace mlir;
// Static initialization for Affine op dialect registration.
static DialectRegistration<AffineOpsDialect> StandardOps;

View File

@ -21,6 +21,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/NestedMatcher.h"
@ -246,6 +247,16 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
return false;
}
// No vectorization across unknown regions.
auto regions = matcher::Op([](const Instruction &inst) -> bool {
auto &opInst = cast<OperationInst>(inst);
return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>();
});
auto regionsMatched = regions.match(forInst);
if (!regionsMatched.empty()) {
return false;
}
auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
auto vectorTransfersMatched = vectorTransfers.match(forInst);
if (!vectorTransfersMatched.empty()) {

View File

@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/ArrayRef.h"
@ -186,6 +187,11 @@ FilterFunctionType NestedPattern::getFilterFunction() {
return storage->filter;
}
static bool isAffineIfOp(const Instruction &inst) {
return isa<OperationInst>(inst) &&
cast<OperationInst>(inst).isa<AffineIfOp>();
}
namespace mlir {
namespace matcher {
@ -194,16 +200,22 @@ NestedPattern Op(FilterFunctionType filter) {
}
NestedPattern If(NestedPattern child) {
return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction);
return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
return NestedPattern(Instruction::Kind::If, child, filter);
return NestedPattern(Instruction::Kind::OperationInst, child,
[filter](const Instruction &inst) {
return isAffineIfOp(inst) && filter(inst);
});
}
NestedPattern If(ArrayRef<NestedPattern> nested) {
return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction);
return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp);
}
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
return NestedPattern(Instruction::Kind::If, nested, filter);
return NestedPattern(Instruction::Kind::OperationInst, nested,
[filter](const Instruction &inst) {
return isAffineIfOp(inst) && filter(inst);
});
}
NestedPattern For(NestedPattern child) {

View File

@ -22,6 +22,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Builders.h"
@ -43,7 +44,7 @@ void mlir::getLoopIVs(const Instruction &inst,
// Traverse up the hierarchy collecing all 'for' instruction while skipping
// over 'if' instructions.
while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
isa<IfInst>(currInst))) {
cast<OperationInst>(currInst)->isa<AffineIfOp>())) {
if (currForInst)
loops->push_back(currForInst);
currInst = currInst->getParentInst();
@ -359,21 +360,12 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
if (auto *childForInst = dyn_cast<ForInst>(&inst))
return getInstAtPosition(positions, level + 1, childForInst->getBody());
if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen());
if (ret != nullptr)
return ret;
if (auto *elseClause = ifInst->getElse())
return getInstAtPosition(positions, level + 1, elseClause);
}
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
for (auto &blockList : opInst->getBlockLists()) {
for (auto &b : blockList)
if (auto *ret = getInstAtPosition(positions, level + 1, &b))
return ret;
}
return nullptr;
for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) {
for (auto &b : blockList)
if (auto *ret = getInstAtPosition(positions, level + 1, &b))
return ret;
}
return nullptr;
}
return nullptr;
}

View File

@ -73,7 +73,6 @@ public:
bool verifyBlock(const Block &block, bool isTopLevel);
bool verifyOperation(const OperationInst &op);
bool verifyForInst(const ForInst &forInst);
bool verifyIfInst(const IfInst &ifInst);
bool verifyDominance(const Block &block);
bool verifyInstDominance(const Instruction &inst);
@ -180,10 +179,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) {
if (verifyForInst(cast<ForInst>(inst)))
return true;
break;
case Instruction::Kind::If:
if (verifyIfInst(cast<IfInst>(inst)))
return true;
break;
}
}
@ -250,18 +245,6 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) {
return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false);
}
bool FuncVerifier::verifyIfInst(const IfInst &ifInst) {
// TODO: check that if conditions are properly formed.
if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false))
return true;
if (auto *elseClause = ifInst.getElse())
if (verifyBlock(*elseClause, /*isTopLevel*/ false))
return true;
return false;
}
bool FuncVerifier::verifyDominance(const Block &block) {
for (auto &inst : block) {
// Check that all operands on the instruction are ok.
@ -283,14 +266,6 @@ bool FuncVerifier::verifyDominance(const Block &block) {
if (verifyDominance(*cast<ForInst>(inst).getBody()))
return true;
break;
case Instruction::Kind::If:
auto &ifInst = cast<IfInst>(inst);
if (verifyDominance(*ifInst.getThen()))
return true;
if (auto *elseClause = ifInst.getElse())
if (verifyDominance(*elseClause))
return true;
break;
}
}
return false;

View File

@ -145,7 +145,6 @@ private:
// Visit functions.
void visitInstruction(const Instruction *inst);
void visitForInst(const ForInst *forInst);
void visitIfInst(const IfInst *ifInst);
void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
@ -197,10 +196,6 @@ void ModuleState::visitAttribute(Attribute attr) {
}
}
void ModuleState::visitIfInst(const IfInst *ifInst) {
recordIntegerSetReference(ifInst->getIntegerSet());
}
void ModuleState::visitForInst(const ForInst *forInst) {
AffineMap lbMap = forInst->getLowerBoundMap();
if (!hasCustomForm(lbMap))
@ -225,8 +220,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) {
void ModuleState::visitInstruction(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::If:
return visitIfInst(cast<IfInst>(inst));
case Instruction::Kind::For:
return visitForInst(cast<ForInst>(inst));
case Instruction::Kind::OperationInst:
@ -1077,7 +1070,6 @@ public:
void print(const Instruction *inst);
void print(const OperationInst *inst);
void print(const ForInst *inst);
void print(const IfInst *inst);
void print(const Block *block, bool printBlockArgs = true);
void printOperation(const OperationInst *op);
@ -1125,6 +1117,9 @@ public:
unsigned index) override;
/// Print a block list.
void printBlockList(const BlockList &blocks) override {
printBlockList(blocks, /*printEntryBlockArgs=*/true);
}
void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
os << " {\n";
if (!blocks.empty()) {
@ -1214,12 +1209,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
// Recursively number the stuff in the body.
numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
break;
case Instruction::Kind::If: {
auto *ifInst = cast<IfInst>(&inst);
numberValuesInBlock(*ifInst->getThen());
if (auto *elseBlock = ifInst->getElse())
numberValuesInBlock(*elseBlock);
}
}
}
}
@ -1360,8 +1349,7 @@ void FunctionPrinter::printFunctionSignature() {
}
void FunctionPrinter::print(const Block *block, bool printBlockArgs) {
// Print the block label and argument list, unless this is the first block of
// the function, or the first block of an IfInst/ForInst with no arguments.
// Print the block label and argument list if requested.
if (printBlockArgs) {
os.indent(currentIndent);
printBlockName(block);
@ -1418,8 +1406,6 @@ void FunctionPrinter::print(const Instruction *inst) {
return print(cast<OperationInst>(inst));
case Instruction::Kind::For:
return print(cast<ForInst>(inst));
case Instruction::Kind::If:
return print(cast<IfInst>(inst));
}
}
@ -1447,22 +1433,6 @@ void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "}";
}
void FunctionPrinter::print(const IfInst *inst) {
os.indent(currentIndent) << "if ";
IntegerSet set = inst->getIntegerSet();
printIntegerSetReference(set);
printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
printTrailingLocation(inst->getLoc());
os << " {\n";
print(inst->getThen(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
if (inst->hasElse()) {
os << " else {\n";
print(inst->getElse(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
}
}
void FunctionPrinter::printValueID(const Value *value,
bool printResultNo) const {
int resultNo = -1;

View File

@ -327,10 +327,3 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
auto ubMap = AffineMap::getConstantMap(ub, context);
return createFor(location, {}, lbMap, {}, ubMap, step);
}
IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
auto *inst = IfInst::create(location, operands, set);
block->getInstructions().insert(insertPoint, inst);
return inst;
}

View File

@ -73,9 +73,6 @@ void Instruction::destroy() {
case Kind::For:
delete cast<ForInst>(this);
break;
case Kind::If:
delete cast<IfInst>(this);
break;
}
}
@ -141,8 +138,6 @@ unsigned Instruction::getNumOperands() const {
return cast<OperationInst>(this)->getNumOperands();
case Kind::For:
return cast<ForInst>(this)->getNumOperands();
case Kind::If:
return cast<IfInst>(this)->getNumOperands();
}
}
@ -152,8 +147,6 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() {
return cast<OperationInst>(this)->getInstOperands();
case Kind::For:
return cast<ForInst>(this)->getInstOperands();
case Kind::If:
return cast<IfInst>(this)->getInstOperands();
}
}
@ -287,15 +280,6 @@ void Instruction::dropAllReferences() {
// Make sure to drop references held by instructions within the body.
cast<ForInst>(this)->getBody()->dropAllReferences();
break;
case Kind::If: {
// Make sure to drop references held by instructions within the 'then' and
// 'else' blocks.
auto *ifInst = cast<IfInst>(this);
ifInst->getThen()->dropAllReferences();
if (auto *elseBlock = ifInst->getElse())
elseBlock->dropAllReferences();
break;
}
case Kind::OperationInst: {
auto *opInst = cast<OperationInst>(this);
if (isTerminator())
@ -809,54 +793,6 @@ mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
results.push_back(forInst->getInductionVar());
return results;
}
//===----------------------------------------------------------------------===//
// IfInst
//===----------------------------------------------------------------------===//
IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set)
: Instruction(Kind::If, location), thenClause(this), elseClause(nullptr),
set(set) {
operands.reserve(numOperands);
// The then of an 'if' inst always has one block.
thenClause.push_back(new Block());
}
IfInst::~IfInst() {
if (elseClause)
delete elseClause;
// An IfInst's IntegerSet 'set' should not be deleted since it is
// allocated through MLIRContext's bump pointer allocator.
}
IfInst *IfInst::create(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
unsigned numOperands = operands.size();
assert(numOperands == set.getNumOperands() &&
"operand cound does not match the integer set operand count");
IfInst *inst = new IfInst(location, numOperands, set);
for (auto *op : operands)
inst->operands.emplace_back(InstOperand(inst, op));
return inst;
}
const AffineCondition IfInst::getCondition() const {
return AffineCondition(*this, set);
}
MLIRContext *IfInst::getContext() const {
// Check for degenerate case of if instruction with no operands.
// This is unlikely, but legal.
if (operands.empty())
return getFunction()->getContext();
return getOperand(0)->getType().getContext();
}
//===----------------------------------------------------------------------===//
// Instruction Cloning
//===----------------------------------------------------------------------===//
@ -931,40 +867,23 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
for (auto *opValue : getOperands())
operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue)));
if (auto *forInst = dyn_cast<ForInst>(this)) {
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();
// Otherwise, this must be a ForInst.
auto *forInst = cast<ForInst>(this);
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();
auto *newFor = ForInst::create(
getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
ubMap, forInst->getStep());
auto *newFor = ForInst::create(
getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), ubMap,
forInst->getStep());
// Remember the induction variable mapping.
mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
// Remember the induction variable mapping.
mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
// Recursively clone the body of the for loop.
for (auto &subInst : *forInst->getBody())
newFor->getBody()->push_back(subInst.clone(mapper, context));
return newFor;
}
// Otherwise, we must have an If instruction.
auto *ifInst = cast<IfInst>(this);
auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet());
auto *resultThen = newIf->getThen();
for (auto &childInst : *ifInst->getThen())
resultThen->push_back(childInst.clone(mapper, context));
if (ifInst->hasElse()) {
auto *resultElse = newIf->createElse();
for (auto &childInst : *ifInst->getElse())
resultElse->push_back(childInst.clone(mapper, context));
}
return newIf;
// Recursively clone the body of the for loop.
for (auto &subInst : *forInst->getBody())
newFor->getBody()->push_back(subInst.clone(mapper, context));
return newFor;
}
Instruction *Instruction::clone(MLIRContext *context) const {

View File

@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
if (!block || &block->back() != op)
return op->emitOpError("must be the last instruction in the parent block");
// Terminators may not exist in ForInst and IfInst.
// TODO(riverriddle) Terminators may not exist with an operation region.
if (block->getContainingInst())
return op->emitOpError("may only be at the top level of a function");

View File

@ -66,8 +66,6 @@ MLIRContext *IROperandOwner::getContext() const {
return cast<OperationInst>(this)->getContext();
case Kind::ForInst:
return cast<ForInst>(this)->getContext();
case Kind::IfInst:
return cast<IfInst>(this)->getContext();
}
}

View File

@ -996,8 +996,7 @@ Attribute Parser::parseAttribute(Type type) {
AffineMap map;
IntegerSet set;
if (parseAffineMapOrIntegerSetReference(map, set))
return (emitError("expected affine map or integer set attribute value"),
nullptr);
return nullptr;
if (map)
return builder.getAffineMapAttr(map);
assert(set);
@ -2209,8 +2208,6 @@ public:
const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
bool isLower);
ParseResult parseIfInst();
ParseResult parseElseClause(Block *elseClause);
ParseResult parseInstructions(Block *block);
private:
@ -2392,10 +2389,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) {
if (parseForInst())
return ParseFailure;
break;
case Token::kw_if:
if (parseIfInst())
return ParseFailure;
break;
}
}
@ -2935,12 +2928,18 @@ public:
return false;
}
/// Parse a keyword followed by a type.
bool parseKeywordType(const char *keyword, Type &result) override {
if (parser.getTokenSpelling() != keyword)
return parser.emitError("expected '" + Twine(keyword) + "'");
parser.consumeToken();
return !(result = parser.parseType());
/// Parse an optional keyword.
bool parseOptionalKeyword(const char *keyword) override {
// Check that the current token is a bare identifier or keyword.
if (parser.getToken().isNot(Token::bare_identifier) &&
!parser.getToken().isKeyword())
return true;
if (parser.getTokenSpelling() == keyword) {
parser.consumeToken();
return false;
}
return true;
}
/// Parse an arbitrary attribute of a given type and return it in result. This
@ -3078,6 +3077,15 @@ public:
return result == nullptr;
}
/// Parses a list of blocks.
bool parseBlockList() override {
SmallVector<Block *, 2> results;
if (parser.parseOperationBlockList(results))
return true;
parsedBlockLists.emplace_back(results);
return false;
}
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
@ -3099,6 +3107,11 @@ public:
/// Emit a diagnostic at the specified location and return true.
bool emitError(llvm::SMLoc loc, const Twine &message) override {
// If we emit an error, then cleanup any parsed block lists.
for (auto &blockList : parsedBlockLists)
parser.cleanupInvalidBlocks(blockList);
parsedBlockLists.clear();
parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
emittedError = true;
return true;
@ -3106,7 +3119,13 @@ public:
bool didEmitError() const { return emittedError; }
/// Returns the block lists that were parsed.
MutableArrayRef<SmallVector<Block *, 2>> getParsedBlockLists() {
return parsedBlockLists;
}
private:
std::vector<SmallVector<Block *, 2>> parsedBlockLists;
SMLoc nameLoc;
StringRef opName;
FunctionParser &parser;
@ -3145,8 +3164,25 @@ OperationInst *FunctionParser::parseCustomOperation() {
if (opAsmParser.didEmitError())
return nullptr;
// Check that enough block lists were reserved for those that were parsed.
auto parsedBlockLists = opAsmParser.getParsedBlockLists();
if (parsedBlockLists.size() > opState.numBlockLists) {
opAsmParser.emitError(
opLoc,
"parsed more block lists than those reserved in the operation state");
return nullptr;
}
// Otherwise, we succeeded. Use the state it parsed as our op information.
return builder.createOperation(opState);
auto *opInst = builder.createOperation(opState);
// Resolve any parsed block lists.
for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) {
auto &opBlockList = opInst->getBlockList(i).getBlocks();
opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(),
parsedBlockLists[i].end());
}
return opInst;
}
/// For instruction.
@ -3438,69 +3474,6 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
}
/// If instruction.
///
/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}`
/// | ml-if-head `else` `if` ml-if-cond trailing-location?
/// `{` inst* `}`
/// ml-if-inst ::= ml-if-head
/// | ml-if-head `else` `{` inst* `}`
///
ParseResult FunctionParser::parseIfInst() {
auto loc = getToken().getLoc();
consumeToken(Token::kw_if);
IntegerSet set = parseIntegerSetReference();
if (!set)
return ParseFailure;
SmallVector<Value *, 4> operands;
if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(),
"integer set"))
return ParseFailure;
IfInst *ifInst =
builder.createIf(getEncodedSourceLocation(loc), operands, set);
// Try to parse the optional trailing location.
if (parseOptionalTrailingLocation(ifInst))
return ParseFailure;
Block *thenClause = ifInst->getThen();
// When parsing of an if instruction body fails, the IR contains
// the if instruction with the portion of the body that has been
// successfully parsed.
if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
parseBlock(thenClause) ||
parseToken(Token::r_brace, "expected '}' after instruction list"))
return ParseFailure;
if (consumeIf(Token::kw_else)) {
auto *elseClause = ifInst->createElse();
if (parseElseClause(elseClause))
return ParseFailure;
}
// Reset insertion point to the current block.
builder.setInsertionPointToEnd(ifInst->getBlock());
return ParseSuccess;
}
ParseResult FunctionParser::parseElseClause(Block *elseClause) {
if (getToken().is(Token::kw_if)) {
builder.setInsertionPointToEnd(elseClause);
return parseIfInst();
}
if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
parseBlock(elseClause) ||
parseToken(Token::r_brace, "expected '}' after instruction list"))
return ParseFailure;
return ParseSuccess;
}
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
//===----------------------------------------------------------------------===//

View File

@ -91,7 +91,6 @@ TOK_KEYWORD(attributes)
TOK_KEYWORD(bf16)
TOK_KEYWORD(ceildiv)
TOK_KEYWORD(dense)
TOK_KEYWORD(else)
TOK_KEYWORD(splat)
TOK_KEYWORD(f16)
TOK_KEYWORD(f32)
@ -100,7 +99,6 @@ TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)
TOK_KEYWORD(func)
TOK_KEYWORD(if)
TOK_KEYWORD(index)
TOK_KEYWORD(loc)
TOK_KEYWORD(max)

View File

@ -188,16 +188,6 @@ void CSE::simplifyBlock(Block *bb) {
simplifyBlock(cast<ForInst>(i).getBody());
break;
}
case Instruction::Kind::If: {
auto &ifInst = cast<IfInst>(i);
if (auto *elseBlock = ifInst.getElse()) {
ScopedMapTy::ScopeTy scope(knownValues);
simplifyBlock(elseBlock);
}
ScopedMapTy::ScopeTy scope(knownValues);
simplifyBlock(ifInst.getThen());
break;
}
}
}
}

View File

@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
@ -99,16 +100,16 @@ public:
SmallVector<ForInst *, 4> forInsts;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
bool hasIfInst = false;
bool hasNonForRegion = false;
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
void visitIfInst(IfInst *ifInst) { hasIfInst = true; }
void visitOperationInst(OperationInst *opInst) {
if (opInst->isa<LoadOp>())
if (opInst->getNumBlockLists() != 0)
hasNonForRegion = true;
else if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
if (opInst->isa<StoreOp>())
else if (opInst->isa<StoreOp>())
storeOpInsts.push_back(opInst);
}
};
@ -410,8 +411,8 @@ bool MemRefDependenceGraph::init(Function *f) {
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.walkForInst(forInst);
// Return false if IfInsts are found (not currently supported).
if (collector.hasIfInst)
// Return false if a non 'for' region was found (not currently supported).
if (collector.hasNonForRegion)
return false;
Node node(id++, &inst);
for (auto *opInst : collector.loadOpInsts) {
@ -434,19 +435,18 @@ bool MemRefDependenceGraph::init(Function *f) {
auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
Node node(id++, &inst);
node.stores.push_back(opInst);
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (opInst->getNumBlockLists() != 0) {
// Return false if another region is found (not currently supported).
return false;
}
}
// Return false if IfInsts are found (not currently supported).
if (isa<IfInst>(&inst))
return false;
}
// Walk memref access lists and add graph edges between dependent nodes.

View File

@ -119,15 +119,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
return true;
}
bool walkIfInstPostOrder(IfInst *ifInst) {
bool hasInnerLoops =
walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end());
if (ifInst->hasElse())
hasInnerLoops |=
walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end());
return hasInnerLoops;
}
bool walkOpInstPostOrder(OperationInst *opInst) {
for (auto &blockList : opInst->getBlockLists())
for (auto &block : blockList)

View File

@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@ -246,7 +247,7 @@ public:
PassResult runOnFunction(Function *function) override;
bool lowerForInst(ForInst *forInst);
bool lowerIfInst(IfInst *ifInst);
bool lowerAffineIf(AffineIfOp *ifOp);
bool lowerAffineApply(AffineApplyOp *op);
static char passID;
@ -409,7 +410,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
// enabling easy nesting of "if" instructions and if-then-else-if chains.
//
// +--------------------------------+
// | <code before the IfInst> |
// | <code before the AffineIfOp> |
// | %zero = constant 0 : index |
// | %v = affine_apply #expr1(%ops) |
// | %c = cmpi "sge" %v, %zero |
@ -453,10 +454,11 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
// v v
// +--------------------------------+
// | continue: |
// | <code after the IfInst> |
// | <code after the AffineIfOp> |
// +--------------------------------+
//
bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) {
auto *ifInst = ifOp->getInstruction();
auto loc = ifInst->getLoc();
// Start by splitting the block containing the 'if' into two parts. The part
@ -466,22 +468,38 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
auto *continueBlock = condBlock->splitBlock(ifInst);
// Create a block for the 'then' code, inserting it between the cond and
// continue blocks. Move the instructions over from the IfInst and add a
// continue blocks. Move the instructions over from the AffineIfOp and add a
// branch to the continuation point.
Block *thenBlock = new Block();
thenBlock->insertBefore(continueBlock);
auto *oldThen = ifInst->getThen();
thenBlock->getInstructions().splice(thenBlock->begin(),
oldThen->getInstructions(),
oldThen->begin(), oldThen->end());
// If the 'then' block is not empty, then splice the instructions.
auto &oldThenBlocks = ifOp->getThenBlocks();
if (!oldThenBlocks.empty()) {
// We currently only handle one 'then' block.
if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end())
return true;
Block *oldThen = &oldThenBlocks.front();
thenBlock->getInstructions().splice(thenBlock->begin(),
oldThen->getInstructions(),
oldThen->begin(), oldThen->end());
}
FuncBuilder builder(thenBlock);
builder.create<BranchOp>(loc, continueBlock);
// Handle the 'else' block the same way, but we skip it if we have no else
// code.
Block *elseBlock = continueBlock;
if (auto *oldElse = ifInst->getElse()) {
auto &oldElseBlocks = ifOp->getElseBlocks();
if (!oldElseBlocks.empty()) {
// We currently only handle one 'else' block.
if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end())
return true;
auto *oldElse = &oldElseBlocks.front();
elseBlock = new Block();
elseBlock->insertBefore(continueBlock);
@ -493,7 +511,7 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
}
// Ok, now we just have to handle the condition logic.
auto integerSet = ifInst->getCondition().getIntegerSet();
auto integerSet = ifOp->getIntegerSet();
// Implement short-circuit logic. For each affine expression in the 'if'
// condition, convert it into an affine map and call `affine_apply` to obtain
@ -593,29 +611,30 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) {
PassResult LowerAffinePass::runOnFunction(Function *function) {
SmallVector<Instruction *, 8> instsToRewrite;
// Collect all the If and For instructions as well as AffineApplyOps. We do
// this as a prepass to avoid invalidating the walker with our rewrite.
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
// We do this as a prepass to avoid invalidating the walker with our rewrite.
function->walkInsts([&](Instruction *inst) {
if (isa<IfInst>(inst) || isa<ForInst>(inst))
if (isa<ForInst>(inst))
instsToRewrite.push_back(inst);
auto op = dyn_cast<OperationInst>(inst);
if (op && op->isa<AffineApplyOp>())
if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>()))
instsToRewrite.push_back(inst);
});
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
// so we know that we will rewrite them in the same order.
for (auto *inst : instsToRewrite)
if (auto *ifInst = dyn_cast<IfInst>(inst)) {
if (lowerIfInst(ifInst))
return failure();
} else if (auto *forInst = dyn_cast<ForInst>(inst)) {
if (auto *forInst = dyn_cast<ForInst>(inst)) {
if (lowerForInst(forInst))
return failure();
} else {
auto op = cast<OperationInst>(inst);
if (lowerAffineApply(op->cast<AffineApplyOp>()))
if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
if (lowerAffineIf(ifOp))
return failure();
} else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
return failure();
}
}
return success();

View File

@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/LoopAnalysis.h"
@ -559,9 +560,6 @@ static bool instantiateMaterialization(Instruction *inst,
if (isa<ForInst>(inst))
return inst->emitError("NYI path ForInst");
if (isa<IfInst>(inst))
return inst->emitError("NYI path IfInst");
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(inst);
auto *opInst = cast<OperationInst>(inst);
@ -570,6 +568,9 @@ static bool instantiateMaterialization(Instruction *inst,
if (opInst->isa<AffineApplyOp>()) {
return false;
}
if (opInst->getNumBlockLists() != 0)
return inst->emitError("NYI path Op with region");
if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
auto *clone = instantiate(&b, write, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);

View File

@ -28,7 +28,6 @@
#define DEBUG_TYPE "simplify-affine-structure"
using namespace mlir;
using llvm::report_fatal_error;
namespace {
@ -42,9 +41,6 @@ struct SimplifyAffineStructures : public FunctionPass {
PassResult runOnFunction(Function *f) override;
void visitIfInst(IfInst *ifInst);
void visitOperationInst(OperationInst *opInst);
static char passID;
};
@ -66,28 +62,19 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
return set;
}
void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) {
auto set = ifInst->getCondition().getIntegerSet();
ifInst->setIntegerSet(simplifyIntegerSet(set));
}
void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) {
for (auto attr : opInst->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr.getValue());
mMap.simplify();
auto map = mMap.getAffineMap();
opInst->setAttr(attr.first, AffineMapAttr::get(map));
}
}
}
PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
f->walkInsts([&](Instruction *inst) {
if (auto *opInst = dyn_cast<OperationInst>(inst))
visitOperationInst(opInst);
if (auto *ifInst = dyn_cast<IfInst>(inst))
visitIfInst(ifInst);
f->walkOps([&](OperationInst *opInst) {
for (auto attr : opInst->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr.getValue());
mMap.simplify();
auto map = mMap.getAffineMap();
opInst->setAttr(attr.first, AffineMapAttr::get(map));
} else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) {
auto simplified = simplifyIntegerSet(setAttr.getValue());
opInst->setAttr(attr.first, IntegerSetAttr::get(simplified));
}
}
});
return success();

View File

@ -243,14 +243,6 @@ func @non_instruction() {
// -----
func @invalid_if_conditional1() {
for %i = 1 to 10 {
if () { // expected-error {{expected ':' or '['}}
}
}
// -----
func @invalid_if_conditional2() {
for %i = 1 to 10 {
if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
@ -664,7 +656,11 @@ func @invalid_if_operands2(%N : index) {
func @invalid_if_operands3(%N : index) {
for %i = 1 to 10 {
if #set0(%i)[%i] {
// expected-error@-1 {{value '%i' cannot be used as a symbol}}
// expected-error@-1 {{operand cannot be used as a symbol}}
}
}
return
}
// -----
// expected-error@+1 {{expected '"' in string literal}}

View File

@ -16,9 +16,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
}
// CHECK: ) loc(fused<"myPass">["foo", "foo2"])
if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) {
}
// CHECK: } loc(fused<"myPass">["foo", "foo2"])
if #set0(%2) {
} loc(fused<"myPass">["foo", "foo2"])
// CHECK: return %0 : i32 loc(unknown)
return %1 : i32 loc(unknown)

View File

@ -287,13 +287,15 @@ func @ifinst(%N: index) {
// CHECK: %c1_i32 = constant 1 : i32
%y = "add"(%x, %i) : (i32, index) -> i32 // CHECK: %0 = "add"(%c1_i32, %i0) : (i32, index) -> i32
%z = "mul"(%y, %y) : (i32, i32) -> i32 // CHECK: %1 = "mul"(%0, %0) : (i32, i32) -> i32
} else if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK } else if (#set1(%i0)[%arg0]) {
// CHECK: %c1 = constant 1 : index
%u = constant 1 : index
// CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1]
%w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u]
} else { // CHECK } else {
%v = constant 3 : i32 // %c3_i32 = constant 3 : i32
} else { // CHECK } else {
if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK if (#set1(%i0)[%arg0]) {
// CHECK: %c1 = constant 1 : index
%u = constant 1 : index
// CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1]
%w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u]
} else { // CHECK } else {
%v = constant 3 : i32 // %c3_i32 = constant 3 : i32
}
} // CHECK }
} // CHECK }
return // CHECK return
@ -751,11 +753,11 @@ func @type_alias() -> !i32_type_alias {
func @verbose_if(%N: index) {
%c = constant 200 : index
// CHECK: "if"(%c200, %arg0, %c200) {cond: #set0} : (index, index, index) -> () {
"if"(%c, %N, %c) { cond: #set0 } : (index, index, index) -> () {
// CHECK: if #set0(%c200)[%arg0, %c200] {
"if"(%c, %N, %c) { condition: #set0 } : (index, index, index) -> () {
// CHECK-NEXT: "add"
%y = "add"(%c, %N) : (index, index) -> index
// CHECK-NEXT: } {
// CHECK-NEXT: } else {
} { // The else block list.
// CHECK-NEXT: "add"
%z = "add"(%c, %c) : (index, index) -> index

View File

@ -21,10 +21,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
}
// CHECK: ) <"myPass">["foo", "foo2"]
if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) {
}
// CHECK: } <"myPass">["foo", "foo2"]
if #set0(%2) {
} loc(fused<"myPass">["foo", "foo2"])
// CHECK: return %0 : i32 [unknown]
return %1 : i32 loc(unknown)
}
}

View File

@ -483,7 +483,7 @@ func @should_not_fuse_if_inst_at_top_level() {
%c0 = constant 4 : index
if #set0(%c0) {
}
// Top-level IfInst should prevent fusion.
// Top-level IfOp should prevent fusion.
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
// CHECK-NEXT: }
@ -512,7 +512,7 @@ func @should_not_fuse_if_inst_in_loop_nest() {
%v0 = load %m[%i1] : memref<10xf32>
}
// IfInst in ForInst should prevent fusion.
// IfOp in ForInst should prevent fusion.
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
// CHECK-NEXT: }

View File

@ -10,7 +10,7 @@ func @store_may_execute_before_load() {
%cf7 = constant 7.0 : f32
%c0 = constant 4 : index
// There is a dependence from store 0 to load 1 at depth 1 because the
// ancestor IfInst of the store, dominates the ancestor ForSmt of the load,
// ancestor IfOp of the store, dominates the ancestor ForSmt of the load,
// and thus the store "may" conditionally execute before the load.
if #set0(%c0) {
for %i0 = 0 to 10 {

View File

@ -13,10 +13,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
}
// CHECK: if #set0(%c4) loc(unknown)
// CHECK: } loc(unknown)
%2 = constant 4 : index
if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) {
}
if #set0(%2) {
} loc(fused<"myPass">["foo", "foo2"])
// CHECK: return %0 : i32 loc(unknown)
return %1 : i32 loc("bar")