forked from OSchip/llvm-project
Recommit: Define a AffineOps dialect as well as an AffineIfOp operation. Replace all instances of IfInst with AffineIfOp and delete IfInst.
PiperOrigin-RevId: 231342063
This commit is contained in:
parent
39d81f246a
commit
755538328b
|
@ -0,0 +1,91 @@
|
|||
//===- AffineOps.h - MLIR Affine Operations -------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines convenience types for working with Affine operations
|
||||
// in the MLIR instruction set.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_AFFINEOPS_AFFINEOPS_H
|
||||
#define MLIR_AFFINEOPS_AFFINEOPS_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AffineOpsDialect : public Dialect {
|
||||
public:
|
||||
AffineOpsDialect(MLIRContext *context);
|
||||
};
|
||||
|
||||
/// The "if" operation represents an if–then–else construct for conditionally
|
||||
/// executing two regions of code. The operands to an if operation are an
|
||||
/// IntegerSet condition and a set of symbol/dimension operands to the
|
||||
/// condition set. The operation produces no results. For example:
|
||||
///
|
||||
/// if #set(%i) {
|
||||
/// ...
|
||||
/// } else {
|
||||
/// ...
|
||||
/// }
|
||||
///
|
||||
/// The 'else' blocks to the if operation are optional, and may be omitted. For
|
||||
/// example:
|
||||
///
|
||||
/// if #set(%i) {
|
||||
/// ...
|
||||
/// }
|
||||
///
|
||||
class AffineIfOp
|
||||
: public Op<AffineIfOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
|
||||
public:
|
||||
// Hooks to customize behavior of this op.
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
IntegerSet condition, ArrayRef<Value *> conditionOperands);
|
||||
|
||||
static StringRef getOperationName() { return "if"; }
|
||||
static StringRef getConditionAttrName() { return "condition"; }
|
||||
|
||||
IntegerSet getIntegerSet() const;
|
||||
void setIntegerSet(IntegerSet newSet);
|
||||
|
||||
/// Returns the list of 'then' blocks.
|
||||
BlockList &getThenBlocks();
|
||||
const BlockList &getThenBlocks() const {
|
||||
return const_cast<AffineIfOp *>(this)->getThenBlocks();
|
||||
}
|
||||
|
||||
/// Returns the list of 'else' blocks.
|
||||
BlockList &getElseBlocks();
|
||||
const BlockList &getElseBlocks() const {
|
||||
return const_cast<AffineIfOp *>(this)->getElseBlocks();
|
||||
}
|
||||
|
||||
bool verify() const;
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
|
||||
private:
|
||||
friend class OperationInst;
|
||||
explicit AffineIfOp(const OperationInst *state) : Op(state) {}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
|
@ -128,7 +128,6 @@ private:
|
|||
void matchOne(Instruction *elem);
|
||||
|
||||
void visitForInst(ForInst *forInst) { matchOne(forInst); }
|
||||
void visitIfInst(IfInst *ifInst) { matchOne(ifInst); }
|
||||
void visitOperationInst(OperationInst *opInst) { matchOne(opInst); }
|
||||
|
||||
/// POD paylod.
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
#include "llvm/ADT/PointerUnion.h"
|
||||
|
||||
namespace mlir {
|
||||
class IfInst;
|
||||
class BlockList;
|
||||
class BlockAndValueMapping;
|
||||
|
||||
|
@ -62,7 +61,7 @@ public:
|
|||
}
|
||||
|
||||
/// Returns the function that this block is part of, even if the block is
|
||||
/// nested under an IfInst or ForInst.
|
||||
/// nested under an OperationInst or ForInst.
|
||||
Function *getFunction();
|
||||
const Function *getFunction() const {
|
||||
return const_cast<Block *>(this)->getFunction();
|
||||
|
@ -325,7 +324,7 @@ private:
|
|||
namespace mlir {
|
||||
|
||||
/// This class contains a list of basic blocks and has a notion of the object it
|
||||
/// is part of - a Function or IfInst or ForInst.
|
||||
/// is part of - a Function or OperationInst or ForInst.
|
||||
class BlockList {
|
||||
public:
|
||||
explicit BlockList(Function *container);
|
||||
|
@ -365,15 +364,16 @@ public:
|
|||
return &BlockList::blocks;
|
||||
}
|
||||
|
||||
/// A BlockList is part of a Function or and IfInst/ForInst. If it is
|
||||
/// part of an IfInst/ForInst, then return it, otherwise return null.
|
||||
/// A BlockList is part of a function or an operation region. If it is
|
||||
/// part of an operation region, then return the operation, otherwise return
|
||||
/// null.
|
||||
Instruction *getContainingInst();
|
||||
const Instruction *getContainingInst() const {
|
||||
return const_cast<BlockList *>(this)->getContainingInst();
|
||||
}
|
||||
|
||||
/// A BlockList is part of a Function or and IfInst/ForInst. If it is
|
||||
/// part of a Function, then return it, otherwise return null.
|
||||
/// A BlockList is part of a function or an operation region. If it is part
|
||||
/// of a Function, then return it, otherwise return null.
|
||||
Function *getContainingFunction();
|
||||
const Function *getContainingFunction() const {
|
||||
return const_cast<BlockList *>(this)->getContainingFunction();
|
||||
|
|
|
@ -286,10 +286,6 @@ public:
|
|||
// Default step is 1.
|
||||
ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
|
||||
|
||||
/// Creates if instruction.
|
||||
IfInst *createIf(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set);
|
||||
|
||||
private:
|
||||
Function *function;
|
||||
Block *block = nullptr;
|
||||
|
|
|
@ -44,7 +44,7 @@
|
|||
// lc.walk(function);
|
||||
// numLoops = lc.numLoops;
|
||||
//
|
||||
// There are 'visit' methods for OperationInst, ForInst, IfInst, and
|
||||
// There are 'visit' methods for OperationInst, ForInst, and
|
||||
// Function, which recursively process all contained instructions.
|
||||
//
|
||||
// Note that if you don't implement visitXXX for some instruction type,
|
||||
|
@ -85,8 +85,6 @@ public:
|
|||
switch (s->getKind()) {
|
||||
case Instruction::Kind::For:
|
||||
return static_cast<SubClass *>(this)->visitForInst(cast<ForInst>(s));
|
||||
case Instruction::Kind::If:
|
||||
return static_cast<SubClass *>(this)->visitIfInst(cast<IfInst>(s));
|
||||
case Instruction::Kind::OperationInst:
|
||||
return static_cast<SubClass *>(this)->visitOperationInst(
|
||||
cast<OperationInst>(s));
|
||||
|
@ -104,7 +102,6 @@ public:
|
|||
// When visiting a for inst, if inst, or an operation inst directly, these
|
||||
// methods get called to indicate when transitioning into a new unit.
|
||||
void visitForInst(ForInst *forInst) {}
|
||||
void visitIfInst(IfInst *ifInst) {}
|
||||
void visitOperationInst(OperationInst *opInst) {}
|
||||
};
|
||||
|
||||
|
@ -166,23 +163,6 @@ public:
|
|||
static_cast<SubClass *>(this)->visitForInst(forInst);
|
||||
}
|
||||
|
||||
void walkIfInst(IfInst *ifInst) {
|
||||
static_cast<SubClass *>(this)->visitIfInst(ifInst);
|
||||
static_cast<SubClass *>(this)->walk(ifInst->getThen()->begin(),
|
||||
ifInst->getThen()->end());
|
||||
if (auto *elseBlock = ifInst->getElse())
|
||||
static_cast<SubClass *>(this)->walk(elseBlock->begin(), elseBlock->end());
|
||||
}
|
||||
|
||||
void walkIfInstPostOrder(IfInst *ifInst) {
|
||||
static_cast<SubClass *>(this)->walkPostOrder(ifInst->getThen()->begin(),
|
||||
ifInst->getThen()->end());
|
||||
if (auto *elseBlock = ifInst->getElse())
|
||||
static_cast<SubClass *>(this)->walkPostOrder(elseBlock->begin(),
|
||||
elseBlock->end());
|
||||
static_cast<SubClass *>(this)->visitIfInst(ifInst);
|
||||
}
|
||||
|
||||
// Function to walk a instruction.
|
||||
RetTy walk(Instruction *s) {
|
||||
static_assert(std::is_base_of<InstWalker, SubClass>::value,
|
||||
|
@ -193,8 +173,6 @@ public:
|
|||
switch (s->getKind()) {
|
||||
case Instruction::Kind::For:
|
||||
return static_cast<SubClass *>(this)->walkForInst(cast<ForInst>(s));
|
||||
case Instruction::Kind::If:
|
||||
return static_cast<SubClass *>(this)->walkIfInst(cast<IfInst>(s));
|
||||
case Instruction::Kind::OperationInst:
|
||||
return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s));
|
||||
}
|
||||
|
@ -210,9 +188,6 @@ public:
|
|||
case Instruction::Kind::For:
|
||||
return static_cast<SubClass *>(this)->walkForInstPostOrder(
|
||||
cast<ForInst>(s));
|
||||
case Instruction::Kind::If:
|
||||
return static_cast<SubClass *>(this)->walkIfInstPostOrder(
|
||||
cast<IfInst>(s));
|
||||
case Instruction::Kind::OperationInst:
|
||||
return static_cast<SubClass *>(this)->walkOpInstPostOrder(
|
||||
cast<OperationInst>(s));
|
||||
|
@ -231,7 +206,6 @@ public:
|
|||
// processing their descendants in some way. When using RetTy, all of these
|
||||
// need to be overridden.
|
||||
void visitForInst(ForInst *forInst) {}
|
||||
void visitIfInst(IfInst *ifInst) {}
|
||||
void visitOperationInst(OperationInst *opInst) {}
|
||||
void visitInstruction(Instruction *inst) {}
|
||||
};
|
||||
|
|
|
@ -75,7 +75,6 @@ public:
|
|||
enum class Kind {
|
||||
OperationInst = (int)IROperandOwner::Kind::OperationInst,
|
||||
For = (int)IROperandOwner::Kind::ForInst,
|
||||
If = (int)IROperandOwner::Kind::IfInst,
|
||||
};
|
||||
|
||||
Kind getKind() const { return (Kind)IROperandOwner::getKind(); }
|
||||
|
|
|
@ -794,130 +794,6 @@ private:
|
|||
|
||||
friend class ForInst;
|
||||
};
|
||||
|
||||
/// If instruction restricts execution to a subset of the loop iteration space.
|
||||
class IfInst : public Instruction {
|
||||
public:
|
||||
static IfInst *create(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set);
|
||||
~IfInst();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Then, else, condition.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
Block *getThen() { return &thenClause.front(); }
|
||||
const Block *getThen() const { return &thenClause.front(); }
|
||||
Block *getElse() { return elseClause ? &elseClause->front() : nullptr; }
|
||||
const Block *getElse() const {
|
||||
return elseClause ? &elseClause->front() : nullptr;
|
||||
}
|
||||
bool hasElse() const { return elseClause != nullptr; }
|
||||
|
||||
Block *createElse() {
|
||||
assert(elseClause == nullptr && "already has an else clause!");
|
||||
elseClause = new BlockList(this);
|
||||
elseClause->push_back(new Block());
|
||||
return &elseClause->front();
|
||||
}
|
||||
|
||||
const AffineCondition getCondition() const;
|
||||
|
||||
IntegerSet getIntegerSet() const { return set; }
|
||||
void setIntegerSet(IntegerSet newSet) {
|
||||
assert(newSet.getNumOperands() == operands.size());
|
||||
set = newSet;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operands
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Operand iterators.
|
||||
using operand_iterator = OperandIterator<IfInst, Value>;
|
||||
using const_operand_iterator = OperandIterator<const IfInst, const Value>;
|
||||
|
||||
/// Operand iterator range.
|
||||
using operand_range = llvm::iterator_range<operand_iterator>;
|
||||
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
|
||||
|
||||
unsigned getNumOperands() const { return operands.size(); }
|
||||
|
||||
Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
|
||||
const Value *getOperand(unsigned idx) const {
|
||||
return getInstOperand(idx).get();
|
||||
}
|
||||
void setOperand(unsigned idx, Value *value) {
|
||||
getInstOperand(idx).set(value);
|
||||
}
|
||||
|
||||
operand_iterator operand_begin() { return operand_iterator(this, 0); }
|
||||
operand_iterator operand_end() {
|
||||
return operand_iterator(this, getNumOperands());
|
||||
}
|
||||
|
||||
const_operand_iterator operand_begin() const {
|
||||
return const_operand_iterator(this, 0);
|
||||
}
|
||||
const_operand_iterator operand_end() const {
|
||||
return const_operand_iterator(this, getNumOperands());
|
||||
}
|
||||
|
||||
ArrayRef<InstOperand> getInstOperands() const { return operands; }
|
||||
MutableArrayRef<InstOperand> getInstOperands() { return operands; }
|
||||
InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
|
||||
const InstOperand &getInstOperand(unsigned idx) const {
|
||||
return getInstOperands()[idx];
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Other
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const IROperandOwner *ptr) {
|
||||
return ptr->getKind() == IROperandOwner::Kind::IfInst;
|
||||
}
|
||||
|
||||
private:
|
||||
// it is always present.
|
||||
BlockList thenClause;
|
||||
// 'else' clause of the if instruction. 'nullptr' if there is no else clause.
|
||||
BlockList *elseClause;
|
||||
|
||||
// The integer set capturing the conditional guard.
|
||||
IntegerSet set;
|
||||
|
||||
// Condition operands.
|
||||
std::vector<InstOperand> operands;
|
||||
|
||||
explicit IfInst(Location location, unsigned numOperands, IntegerSet set);
|
||||
};
|
||||
|
||||
/// AffineCondition represents a condition of the 'if' instruction.
|
||||
/// Its life span should not exceed that of the objects it refers to.
|
||||
/// AffineCondition does not provide its own methods for iterating over
|
||||
/// the operands since the iterators of the if instruction accomplish
|
||||
/// the same purpose.
|
||||
///
|
||||
/// AffineCondition is trivially copyable, so it should be passed by value.
|
||||
class AffineCondition {
|
||||
public:
|
||||
const IfInst *getIfInst() const { return &inst; }
|
||||
IntegerSet getIntegerSet() const { return set; }
|
||||
|
||||
private:
|
||||
// 'if' instruction that contains this affine condition.
|
||||
const IfInst &inst;
|
||||
// Integer set for this affine condition.
|
||||
IntegerSet set;
|
||||
|
||||
AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {}
|
||||
|
||||
friend class IfInst;
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_IR_INSTRUCTIONS_H
|
||||
|
|
|
@ -89,6 +89,9 @@ public:
|
|||
/// Print the entire operation with the default generic assembly form.
|
||||
virtual void printGenericOp(const OperationInst *op) = 0;
|
||||
|
||||
/// Prints a block list.
|
||||
virtual void printBlockList(const BlockList &blocks) = 0;
|
||||
|
||||
private:
|
||||
OpAsmPrinter(const OpAsmPrinter &) = delete;
|
||||
void operator=(const OpAsmPrinter &) = delete;
|
||||
|
@ -195,7 +198,19 @@ public:
|
|||
virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
|
||||
|
||||
/// Parse a keyword followed by a type.
|
||||
virtual bool parseKeywordType(const char *keyword, Type &result) = 0;
|
||||
bool parseKeywordType(const char *keyword, Type &result) {
|
||||
return parseKeyword(keyword) || parseType(result);
|
||||
}
|
||||
|
||||
/// Parse a keyword.
|
||||
bool parseKeyword(const char *keyword) {
|
||||
if (parseOptionalKeyword(keyword))
|
||||
return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'");
|
||||
return false;
|
||||
}
|
||||
|
||||
/// If a keyword is present, then parse it.
|
||||
virtual bool parseOptionalKeyword(const char *keyword) = 0;
|
||||
|
||||
/// Add the specified type to the end of the specified type list and return
|
||||
/// false. This is a helper designed to allow parse methods to be simple and
|
||||
|
@ -296,6 +311,10 @@ public:
|
|||
int requiredOperandCount = -1,
|
||||
Delimiter delimiter = Delimiter::None) = 0;
|
||||
|
||||
/// Parses a block list. Any parsed blocks are filled in to the
|
||||
/// operation's block lists after the operation is created.
|
||||
virtual bool parseBlockList() = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Methods for interacting with the parser
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -81,10 +81,9 @@ public:
|
|||
enum class Kind {
|
||||
OperationInst,
|
||||
ForInst,
|
||||
IfInst,
|
||||
|
||||
/// These enums define ranges used for classof implementations.
|
||||
INST_LAST = IfInst,
|
||||
INST_LAST = ForInst,
|
||||
};
|
||||
|
||||
Kind getKind() const { return locationAndKind.getInt(); }
|
||||
|
|
|
@ -93,7 +93,7 @@ using OwningMLLoweringPatternList =
|
|||
/// next _original_ operation is considered.
|
||||
/// In other words, for each operation, the pass applies the first matching
|
||||
/// rewriter in the list and advances to the (lexically) next operation.
|
||||
/// Non-operation instructions (ForInst and IfInst) are ignored.
|
||||
/// Non-operation instructions (ForInst) are ignored.
|
||||
/// This is similar to greedy worklist-based pattern rewriter, except that this
|
||||
/// operates on ML functions using an ML builder and does not maintain the work
|
||||
/// list. Note that, as of the time of writing, worklist-based rewriter did not
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineOpsDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
|
||||
: Dialect(/*namePrefix=*/"", context) {
|
||||
addOperations<AffineIfOp>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineIfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AffineIfOp::build(Builder *builder, OperationState *result,
|
||||
IntegerSet condition,
|
||||
ArrayRef<Value *> conditionOperands) {
|
||||
result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition));
|
||||
result->addOperands(conditionOperands);
|
||||
|
||||
// Reserve 2 block lists, one for the 'then' and one for the 'else' regions.
|
||||
result->reserveBlockLists(2);
|
||||
}
|
||||
|
||||
bool AffineIfOp::verify() const {
|
||||
// Verify that we have a condition attribute.
|
||||
auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName());
|
||||
if (!conditionAttr)
|
||||
return emitOpError("requires an integer set attribute named 'condition'");
|
||||
|
||||
// Verify that the operands are valid dimension/symbols.
|
||||
IntegerSet condition = conditionAttr.getValue();
|
||||
for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
|
||||
const Value *operand = getOperand(i);
|
||||
if (i < condition.getNumDims() && !operand->isValidDim())
|
||||
return emitOpError("operand cannot be used as a dimension id");
|
||||
if (i >= condition.getNumDims() && !operand->isValidSymbol())
|
||||
return emitOpError("operand cannot be used as a symbol");
|
||||
}
|
||||
|
||||
// Verify that the entry of each child blocklist does not have arguments.
|
||||
for (const auto &blockList : getInstruction()->getBlockLists()) {
|
||||
if (blockList.empty())
|
||||
continue;
|
||||
|
||||
// TODO(riverriddle) We currently do not allow multiple blocks in child
|
||||
// block lists.
|
||||
if (std::next(blockList.begin()) != blockList.end())
|
||||
return emitOpError(
|
||||
"expects only one block per 'if' or 'else' block list");
|
||||
if (blockList.front().getTerminator())
|
||||
return emitOpError("expects region block to not have a terminator");
|
||||
|
||||
for (const auto &b : blockList)
|
||||
if (b.getNumArguments() != 0)
|
||||
return emitOpError(
|
||||
"requires that child entry blocks have no arguments");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
// Parse the condition attribute set.
|
||||
IntegerSetAttr conditionAttr;
|
||||
unsigned numDims;
|
||||
if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(),
|
||||
result->attributes) ||
|
||||
parseDimAndSymbolList(parser, result->operands, numDims))
|
||||
return true;
|
||||
|
||||
// Verify the condition operands.
|
||||
auto set = conditionAttr.getValue();
|
||||
if (set.getNumDims() != numDims)
|
||||
return parser->emitError(
|
||||
parser->getNameLoc(),
|
||||
"dim operand count and integer set dim count must match");
|
||||
if (numDims + set.getNumSymbols() != result->operands.size())
|
||||
return parser->emitError(
|
||||
parser->getNameLoc(),
|
||||
"symbol operand count and integer set symbol count must match");
|
||||
|
||||
// Parse the 'then' block list.
|
||||
if (parser->parseBlockList())
|
||||
return true;
|
||||
|
||||
// If we find an 'else' keyword then parse the else block list.
|
||||
if (!parser->parseOptionalKeyword("else")) {
|
||||
if (parser->parseBlockList())
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reserve 2 block lists, one for the 'then' and one for the 'else' regions.
|
||||
result->reserveBlockLists(2);
|
||||
return false;
|
||||
}
|
||||
|
||||
void AffineIfOp::print(OpAsmPrinter *p) const {
|
||||
auto conditionAttr = getAttrOfType<IntegerSetAttr>(getConditionAttrName());
|
||||
*p << "if " << conditionAttr;
|
||||
printDimAndSymbolList(operand_begin(), operand_end(),
|
||||
conditionAttr.getValue().getNumDims(), p);
|
||||
p->printBlockList(getInstruction()->getBlockList(0));
|
||||
|
||||
// Print the 'else' block list if it has any blocks.
|
||||
const auto &elseBlockList = getInstruction()->getBlockList(1);
|
||||
if (!elseBlockList.empty()) {
|
||||
*p << " else";
|
||||
p->printBlockList(elseBlockList);
|
||||
}
|
||||
}
|
||||
|
||||
IntegerSet AffineIfOp::getIntegerSet() const {
|
||||
return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
|
||||
}
|
||||
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
|
||||
setAttr(
|
||||
Identifier::get(getConditionAttrName(), getInstruction()->getContext()),
|
||||
IntegerSetAttr::get(newSet));
|
||||
}
|
||||
|
||||
/// Returns the list of 'then' blocks.
|
||||
BlockList &AffineIfOp::getThenBlocks() {
|
||||
return getInstruction()->getBlockList(0);
|
||||
}
|
||||
|
||||
/// Returns the list of 'else' blocks.
|
||||
BlockList &AffineIfOp::getElseBlocks() {
|
||||
return getInstruction()->getBlockList(1);
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
using namespace mlir;
|
||||
|
||||
// Static initialization for Affine op dialect registration.
|
||||
static DialectRegistration<AffineOpsDialect> StandardOps;
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
|
@ -246,6 +247,16 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
|
|||
return false;
|
||||
}
|
||||
|
||||
// No vectorization across unknown regions.
|
||||
auto regions = matcher::Op([](const Instruction &inst) -> bool {
|
||||
auto &opInst = cast<OperationInst>(inst);
|
||||
return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>();
|
||||
});
|
||||
auto regionsMatched = regions.match(forInst);
|
||||
if (!regionsMatched.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
|
||||
auto vectorTransfersMatched = vectorTransfers.match(forInst);
|
||||
if (!vectorTransfersMatched.empty()) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
// =============================================================================
|
||||
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
@ -186,6 +187,11 @@ FilterFunctionType NestedPattern::getFilterFunction() {
|
|||
return storage->filter;
|
||||
}
|
||||
|
||||
static bool isAffineIfOp(const Instruction &inst) {
|
||||
return isa<OperationInst>(inst) &&
|
||||
cast<OperationInst>(inst).isa<AffineIfOp>();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace matcher {
|
||||
|
||||
|
@ -194,16 +200,22 @@ NestedPattern Op(FilterFunctionType filter) {
|
|||
}
|
||||
|
||||
NestedPattern If(NestedPattern child) {
|
||||
return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction);
|
||||
return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp);
|
||||
}
|
||||
NestedPattern If(FilterFunctionType filter, NestedPattern child) {
|
||||
return NestedPattern(Instruction::Kind::If, child, filter);
|
||||
return NestedPattern(Instruction::Kind::OperationInst, child,
|
||||
[filter](const Instruction &inst) {
|
||||
return isAffineIfOp(inst) && filter(inst);
|
||||
});
|
||||
}
|
||||
NestedPattern If(ArrayRef<NestedPattern> nested) {
|
||||
return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction);
|
||||
return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp);
|
||||
}
|
||||
NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
|
||||
return NestedPattern(Instruction::Kind::If, nested, filter);
|
||||
return NestedPattern(Instruction::Kind::OperationInst, nested,
|
||||
[filter](const Instruction &inst) {
|
||||
return isAffineIfOp(inst) && filter(inst);
|
||||
});
|
||||
}
|
||||
|
||||
NestedPattern For(NestedPattern child) {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
#include "mlir/Analysis/Utils.h"
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -43,7 +44,7 @@ void mlir::getLoopIVs(const Instruction &inst,
|
|||
// Traverse up the hierarchy collecing all 'for' instruction while skipping
|
||||
// over 'if' instructions.
|
||||
while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
|
||||
isa<IfInst>(currInst))) {
|
||||
cast<OperationInst>(currInst)->isa<AffineIfOp>())) {
|
||||
if (currForInst)
|
||||
loops->push_back(currForInst);
|
||||
currInst = currInst->getParentInst();
|
||||
|
@ -359,21 +360,12 @@ static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
|
|||
if (auto *childForInst = dyn_cast<ForInst>(&inst))
|
||||
return getInstAtPosition(positions, level + 1, childForInst->getBody());
|
||||
|
||||
if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
|
||||
auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen());
|
||||
if (ret != nullptr)
|
||||
return ret;
|
||||
if (auto *elseClause = ifInst->getElse())
|
||||
return getInstAtPosition(positions, level + 1, elseClause);
|
||||
}
|
||||
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
|
||||
for (auto &blockList : opInst->getBlockLists()) {
|
||||
for (auto &b : blockList)
|
||||
if (auto *ret = getInstAtPosition(positions, level + 1, &b))
|
||||
return ret;
|
||||
}
|
||||
return nullptr;
|
||||
for (auto &blockList : cast<OperationInst>(&inst)->getBlockLists()) {
|
||||
for (auto &b : blockList)
|
||||
if (auto *ret = getInstAtPosition(positions, level + 1, &b))
|
||||
return ret;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -73,7 +73,6 @@ public:
|
|||
bool verifyBlock(const Block &block, bool isTopLevel);
|
||||
bool verifyOperation(const OperationInst &op);
|
||||
bool verifyForInst(const ForInst &forInst);
|
||||
bool verifyIfInst(const IfInst &ifInst);
|
||||
bool verifyDominance(const Block &block);
|
||||
bool verifyInstDominance(const Instruction &inst);
|
||||
|
||||
|
@ -180,10 +179,6 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) {
|
|||
if (verifyForInst(cast<ForInst>(inst)))
|
||||
return true;
|
||||
break;
|
||||
case Instruction::Kind::If:
|
||||
if (verifyIfInst(cast<IfInst>(inst)))
|
||||
return true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -250,18 +245,6 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) {
|
|||
return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false);
|
||||
}
|
||||
|
||||
bool FuncVerifier::verifyIfInst(const IfInst &ifInst) {
|
||||
// TODO: check that if conditions are properly formed.
|
||||
if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false))
|
||||
return true;
|
||||
|
||||
if (auto *elseClause = ifInst.getElse())
|
||||
if (verifyBlock(*elseClause, /*isTopLevel*/ false))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool FuncVerifier::verifyDominance(const Block &block) {
|
||||
for (auto &inst : block) {
|
||||
// Check that all operands on the instruction are ok.
|
||||
|
@ -283,14 +266,6 @@ bool FuncVerifier::verifyDominance(const Block &block) {
|
|||
if (verifyDominance(*cast<ForInst>(inst).getBody()))
|
||||
return true;
|
||||
break;
|
||||
case Instruction::Kind::If:
|
||||
auto &ifInst = cast<IfInst>(inst);
|
||||
if (verifyDominance(*ifInst.getThen()))
|
||||
return true;
|
||||
if (auto *elseClause = ifInst.getElse())
|
||||
if (verifyDominance(*elseClause))
|
||||
return true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
|
|
@ -145,7 +145,6 @@ private:
|
|||
// Visit functions.
|
||||
void visitInstruction(const Instruction *inst);
|
||||
void visitForInst(const ForInst *forInst);
|
||||
void visitIfInst(const IfInst *ifInst);
|
||||
void visitOperationInst(const OperationInst *opInst);
|
||||
void visitType(Type type);
|
||||
void visitAttribute(Attribute attr);
|
||||
|
@ -197,10 +196,6 @@ void ModuleState::visitAttribute(Attribute attr) {
|
|||
}
|
||||
}
|
||||
|
||||
void ModuleState::visitIfInst(const IfInst *ifInst) {
|
||||
recordIntegerSetReference(ifInst->getIntegerSet());
|
||||
}
|
||||
|
||||
void ModuleState::visitForInst(const ForInst *forInst) {
|
||||
AffineMap lbMap = forInst->getLowerBoundMap();
|
||||
if (!hasCustomForm(lbMap))
|
||||
|
@ -225,8 +220,6 @@ void ModuleState::visitOperationInst(const OperationInst *op) {
|
|||
|
||||
void ModuleState::visitInstruction(const Instruction *inst) {
|
||||
switch (inst->getKind()) {
|
||||
case Instruction::Kind::If:
|
||||
return visitIfInst(cast<IfInst>(inst));
|
||||
case Instruction::Kind::For:
|
||||
return visitForInst(cast<ForInst>(inst));
|
||||
case Instruction::Kind::OperationInst:
|
||||
|
@ -1077,7 +1070,6 @@ public:
|
|||
void print(const Instruction *inst);
|
||||
void print(const OperationInst *inst);
|
||||
void print(const ForInst *inst);
|
||||
void print(const IfInst *inst);
|
||||
void print(const Block *block, bool printBlockArgs = true);
|
||||
|
||||
void printOperation(const OperationInst *op);
|
||||
|
@ -1125,6 +1117,9 @@ public:
|
|||
unsigned index) override;
|
||||
|
||||
/// Print a block list.
|
||||
void printBlockList(const BlockList &blocks) override {
|
||||
printBlockList(blocks, /*printEntryBlockArgs=*/true);
|
||||
}
|
||||
void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
|
||||
os << " {\n";
|
||||
if (!blocks.empty()) {
|
||||
|
@ -1214,12 +1209,6 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
|
|||
// Recursively number the stuff in the body.
|
||||
numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
|
||||
break;
|
||||
case Instruction::Kind::If: {
|
||||
auto *ifInst = cast<IfInst>(&inst);
|
||||
numberValuesInBlock(*ifInst->getThen());
|
||||
if (auto *elseBlock = ifInst->getElse())
|
||||
numberValuesInBlock(*elseBlock);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1360,8 +1349,7 @@ void FunctionPrinter::printFunctionSignature() {
|
|||
}
|
||||
|
||||
void FunctionPrinter::print(const Block *block, bool printBlockArgs) {
|
||||
// Print the block label and argument list, unless this is the first block of
|
||||
// the function, or the first block of an IfInst/ForInst with no arguments.
|
||||
// Print the block label and argument list if requested.
|
||||
if (printBlockArgs) {
|
||||
os.indent(currentIndent);
|
||||
printBlockName(block);
|
||||
|
@ -1418,8 +1406,6 @@ void FunctionPrinter::print(const Instruction *inst) {
|
|||
return print(cast<OperationInst>(inst));
|
||||
case Instruction::Kind::For:
|
||||
return print(cast<ForInst>(inst));
|
||||
case Instruction::Kind::If:
|
||||
return print(cast<IfInst>(inst));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1447,22 +1433,6 @@ void FunctionPrinter::print(const ForInst *inst) {
|
|||
os.indent(currentIndent) << "}";
|
||||
}
|
||||
|
||||
void FunctionPrinter::print(const IfInst *inst) {
|
||||
os.indent(currentIndent) << "if ";
|
||||
IntegerSet set = inst->getIntegerSet();
|
||||
printIntegerSetReference(set);
|
||||
printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
|
||||
printTrailingLocation(inst->getLoc());
|
||||
os << " {\n";
|
||||
print(inst->getThen(), /*printBlockArgs=*/false);
|
||||
os.indent(currentIndent) << "}";
|
||||
if (inst->hasElse()) {
|
||||
os << " else {\n";
|
||||
print(inst->getElse(), /*printBlockArgs=*/false);
|
||||
os.indent(currentIndent) << "}";
|
||||
}
|
||||
}
|
||||
|
||||
void FunctionPrinter::printValueID(const Value *value,
|
||||
bool printResultNo) const {
|
||||
int resultNo = -1;
|
||||
|
|
|
@ -327,10 +327,3 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
|
|||
auto ubMap = AffineMap::getConstantMap(ub, context);
|
||||
return createFor(location, {}, lbMap, {}, ubMap, step);
|
||||
}
|
||||
|
||||
IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set) {
|
||||
auto *inst = IfInst::create(location, operands, set);
|
||||
block->getInstructions().insert(insertPoint, inst);
|
||||
return inst;
|
||||
}
|
||||
|
|
|
@ -73,9 +73,6 @@ void Instruction::destroy() {
|
|||
case Kind::For:
|
||||
delete cast<ForInst>(this);
|
||||
break;
|
||||
case Kind::If:
|
||||
delete cast<IfInst>(this);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -141,8 +138,6 @@ unsigned Instruction::getNumOperands() const {
|
|||
return cast<OperationInst>(this)->getNumOperands();
|
||||
case Kind::For:
|
||||
return cast<ForInst>(this)->getNumOperands();
|
||||
case Kind::If:
|
||||
return cast<IfInst>(this)->getNumOperands();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -152,8 +147,6 @@ MutableArrayRef<InstOperand> Instruction::getInstOperands() {
|
|||
return cast<OperationInst>(this)->getInstOperands();
|
||||
case Kind::For:
|
||||
return cast<ForInst>(this)->getInstOperands();
|
||||
case Kind::If:
|
||||
return cast<IfInst>(this)->getInstOperands();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -287,15 +280,6 @@ void Instruction::dropAllReferences() {
|
|||
// Make sure to drop references held by instructions within the body.
|
||||
cast<ForInst>(this)->getBody()->dropAllReferences();
|
||||
break;
|
||||
case Kind::If: {
|
||||
// Make sure to drop references held by instructions within the 'then' and
|
||||
// 'else' blocks.
|
||||
auto *ifInst = cast<IfInst>(this);
|
||||
ifInst->getThen()->dropAllReferences();
|
||||
if (auto *elseBlock = ifInst->getElse())
|
||||
elseBlock->dropAllReferences();
|
||||
break;
|
||||
}
|
||||
case Kind::OperationInst: {
|
||||
auto *opInst = cast<OperationInst>(this);
|
||||
if (isTerminator())
|
||||
|
@ -809,54 +793,6 @@ mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
|
|||
results.push_back(forInst->getInductionVar());
|
||||
return results;
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfInst
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set)
|
||||
: Instruction(Kind::If, location), thenClause(this), elseClause(nullptr),
|
||||
set(set) {
|
||||
operands.reserve(numOperands);
|
||||
|
||||
// The then of an 'if' inst always has one block.
|
||||
thenClause.push_back(new Block());
|
||||
}
|
||||
|
||||
IfInst::~IfInst() {
|
||||
if (elseClause)
|
||||
delete elseClause;
|
||||
|
||||
// An IfInst's IntegerSet 'set' should not be deleted since it is
|
||||
// allocated through MLIRContext's bump pointer allocator.
|
||||
}
|
||||
|
||||
IfInst *IfInst::create(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set) {
|
||||
unsigned numOperands = operands.size();
|
||||
assert(numOperands == set.getNumOperands() &&
|
||||
"operand cound does not match the integer set operand count");
|
||||
|
||||
IfInst *inst = new IfInst(location, numOperands, set);
|
||||
|
||||
for (auto *op : operands)
|
||||
inst->operands.emplace_back(InstOperand(inst, op));
|
||||
|
||||
return inst;
|
||||
}
|
||||
|
||||
const AffineCondition IfInst::getCondition() const {
|
||||
return AffineCondition(*this, set);
|
||||
}
|
||||
|
||||
MLIRContext *IfInst::getContext() const {
|
||||
// Check for degenerate case of if instruction with no operands.
|
||||
// This is unlikely, but legal.
|
||||
if (operands.empty())
|
||||
return getFunction()->getContext();
|
||||
|
||||
return getOperand(0)->getType().getContext();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Instruction Cloning
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -931,40 +867,23 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
|
|||
for (auto *opValue : getOperands())
|
||||
operands.push_back(mapper.lookupOrDefault(const_cast<Value *>(opValue)));
|
||||
|
||||
if (auto *forInst = dyn_cast<ForInst>(this)) {
|
||||
auto lbMap = forInst->getLowerBoundMap();
|
||||
auto ubMap = forInst->getUpperBoundMap();
|
||||
// Otherwise, this must be a ForInst.
|
||||
auto *forInst = cast<ForInst>(this);
|
||||
auto lbMap = forInst->getLowerBoundMap();
|
||||
auto ubMap = forInst->getUpperBoundMap();
|
||||
|
||||
auto *newFor = ForInst::create(
|
||||
getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
|
||||
lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
|
||||
ubMap, forInst->getStep());
|
||||
auto *newFor = ForInst::create(
|
||||
getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
|
||||
lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()), ubMap,
|
||||
forInst->getStep());
|
||||
|
||||
// Remember the induction variable mapping.
|
||||
mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
|
||||
// Remember the induction variable mapping.
|
||||
mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
|
||||
|
||||
// Recursively clone the body of the for loop.
|
||||
for (auto &subInst : *forInst->getBody())
|
||||
newFor->getBody()->push_back(subInst.clone(mapper, context));
|
||||
|
||||
return newFor;
|
||||
}
|
||||
|
||||
// Otherwise, we must have an If instruction.
|
||||
auto *ifInst = cast<IfInst>(this);
|
||||
auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet());
|
||||
|
||||
auto *resultThen = newIf->getThen();
|
||||
for (auto &childInst : *ifInst->getThen())
|
||||
resultThen->push_back(childInst.clone(mapper, context));
|
||||
|
||||
if (ifInst->hasElse()) {
|
||||
auto *resultElse = newIf->createElse();
|
||||
for (auto &childInst : *ifInst->getElse())
|
||||
resultElse->push_back(childInst.clone(mapper, context));
|
||||
}
|
||||
|
||||
return newIf;
|
||||
// Recursively clone the body of the for loop.
|
||||
for (auto &subInst : *forInst->getBody())
|
||||
newFor->getBody()->push_back(subInst.clone(mapper, context));
|
||||
return newFor;
|
||||
}
|
||||
|
||||
Instruction *Instruction::clone(MLIRContext *context) const {
|
||||
|
|
|
@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
|
|||
if (!block || &block->back() != op)
|
||||
return op->emitOpError("must be the last instruction in the parent block");
|
||||
|
||||
// Terminators may not exist in ForInst and IfInst.
|
||||
// TODO(riverriddle) Terminators may not exist with an operation region.
|
||||
if (block->getContainingInst())
|
||||
return op->emitOpError("may only be at the top level of a function");
|
||||
|
||||
|
|
|
@ -66,8 +66,6 @@ MLIRContext *IROperandOwner::getContext() const {
|
|||
return cast<OperationInst>(this)->getContext();
|
||||
case Kind::ForInst:
|
||||
return cast<ForInst>(this)->getContext();
|
||||
case Kind::IfInst:
|
||||
return cast<IfInst>(this)->getContext();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -996,8 +996,7 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
AffineMap map;
|
||||
IntegerSet set;
|
||||
if (parseAffineMapOrIntegerSetReference(map, set))
|
||||
return (emitError("expected affine map or integer set attribute value"),
|
||||
nullptr);
|
||||
return nullptr;
|
||||
if (map)
|
||||
return builder.getAffineMapAttr(map);
|
||||
assert(set);
|
||||
|
@ -2209,8 +2208,6 @@ public:
|
|||
const char *affineStructName);
|
||||
ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
|
||||
bool isLower);
|
||||
ParseResult parseIfInst();
|
||||
ParseResult parseElseClause(Block *elseClause);
|
||||
ParseResult parseInstructions(Block *block);
|
||||
|
||||
private:
|
||||
|
@ -2392,10 +2389,6 @@ ParseResult FunctionParser::parseBlockBody(Block *block) {
|
|||
if (parseForInst())
|
||||
return ParseFailure;
|
||||
break;
|
||||
case Token::kw_if:
|
||||
if (parseIfInst())
|
||||
return ParseFailure;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2935,12 +2928,18 @@ public:
|
|||
return false;
|
||||
}
|
||||
|
||||
/// Parse a keyword followed by a type.
|
||||
bool parseKeywordType(const char *keyword, Type &result) override {
|
||||
if (parser.getTokenSpelling() != keyword)
|
||||
return parser.emitError("expected '" + Twine(keyword) + "'");
|
||||
parser.consumeToken();
|
||||
return !(result = parser.parseType());
|
||||
/// Parse an optional keyword.
|
||||
bool parseOptionalKeyword(const char *keyword) override {
|
||||
// Check that the current token is a bare identifier or keyword.
|
||||
if (parser.getToken().isNot(Token::bare_identifier) &&
|
||||
!parser.getToken().isKeyword())
|
||||
return true;
|
||||
|
||||
if (parser.getTokenSpelling() == keyword) {
|
||||
parser.consumeToken();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Parse an arbitrary attribute of a given type and return it in result. This
|
||||
|
@ -3078,6 +3077,15 @@ public:
|
|||
return result == nullptr;
|
||||
}
|
||||
|
||||
/// Parses a list of blocks.
|
||||
bool parseBlockList() override {
|
||||
SmallVector<Block *, 2> results;
|
||||
if (parser.parseOperationBlockList(results))
|
||||
return true;
|
||||
parsedBlockLists.emplace_back(results);
|
||||
return false;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Methods for interacting with the parser
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -3099,6 +3107,11 @@ public:
|
|||
|
||||
/// Emit a diagnostic at the specified location and return true.
|
||||
bool emitError(llvm::SMLoc loc, const Twine &message) override {
|
||||
// If we emit an error, then cleanup any parsed block lists.
|
||||
for (auto &blockList : parsedBlockLists)
|
||||
parser.cleanupInvalidBlocks(blockList);
|
||||
parsedBlockLists.clear();
|
||||
|
||||
parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
|
||||
emittedError = true;
|
||||
return true;
|
||||
|
@ -3106,7 +3119,13 @@ public:
|
|||
|
||||
bool didEmitError() const { return emittedError; }
|
||||
|
||||
/// Returns the block lists that were parsed.
|
||||
MutableArrayRef<SmallVector<Block *, 2>> getParsedBlockLists() {
|
||||
return parsedBlockLists;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<SmallVector<Block *, 2>> parsedBlockLists;
|
||||
SMLoc nameLoc;
|
||||
StringRef opName;
|
||||
FunctionParser &parser;
|
||||
|
@ -3145,8 +3164,25 @@ OperationInst *FunctionParser::parseCustomOperation() {
|
|||
if (opAsmParser.didEmitError())
|
||||
return nullptr;
|
||||
|
||||
// Check that enough block lists were reserved for those that were parsed.
|
||||
auto parsedBlockLists = opAsmParser.getParsedBlockLists();
|
||||
if (parsedBlockLists.size() > opState.numBlockLists) {
|
||||
opAsmParser.emitError(
|
||||
opLoc,
|
||||
"parsed more block lists than those reserved in the operation state");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Otherwise, we succeeded. Use the state it parsed as our op information.
|
||||
return builder.createOperation(opState);
|
||||
auto *opInst = builder.createOperation(opState);
|
||||
|
||||
// Resolve any parsed block lists.
|
||||
for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) {
|
||||
auto &opBlockList = opInst->getBlockList(i).getBlocks();
|
||||
opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(),
|
||||
parsedBlockLists[i].end());
|
||||
}
|
||||
return opInst;
|
||||
}
|
||||
|
||||
/// For instruction.
|
||||
|
@ -3438,69 +3474,6 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
|
|||
return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
|
||||
}
|
||||
|
||||
/// If instruction.
|
||||
///
|
||||
/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}`
|
||||
/// | ml-if-head `else` `if` ml-if-cond trailing-location?
|
||||
/// `{` inst* `}`
|
||||
/// ml-if-inst ::= ml-if-head
|
||||
/// | ml-if-head `else` `{` inst* `}`
|
||||
///
|
||||
ParseResult FunctionParser::parseIfInst() {
|
||||
auto loc = getToken().getLoc();
|
||||
consumeToken(Token::kw_if);
|
||||
|
||||
IntegerSet set = parseIntegerSetReference();
|
||||
if (!set)
|
||||
return ParseFailure;
|
||||
|
||||
SmallVector<Value *, 4> operands;
|
||||
if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(),
|
||||
"integer set"))
|
||||
return ParseFailure;
|
||||
|
||||
IfInst *ifInst =
|
||||
builder.createIf(getEncodedSourceLocation(loc), operands, set);
|
||||
|
||||
// Try to parse the optional trailing location.
|
||||
if (parseOptionalTrailingLocation(ifInst))
|
||||
return ParseFailure;
|
||||
|
||||
Block *thenClause = ifInst->getThen();
|
||||
|
||||
// When parsing of an if instruction body fails, the IR contains
|
||||
// the if instruction with the portion of the body that has been
|
||||
// successfully parsed.
|
||||
if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
|
||||
parseBlock(thenClause) ||
|
||||
parseToken(Token::r_brace, "expected '}' after instruction list"))
|
||||
return ParseFailure;
|
||||
|
||||
if (consumeIf(Token::kw_else)) {
|
||||
auto *elseClause = ifInst->createElse();
|
||||
if (parseElseClause(elseClause))
|
||||
return ParseFailure;
|
||||
}
|
||||
|
||||
// Reset insertion point to the current block.
|
||||
builder.setInsertionPointToEnd(ifInst->getBlock());
|
||||
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
ParseResult FunctionParser::parseElseClause(Block *elseClause) {
|
||||
if (getToken().is(Token::kw_if)) {
|
||||
builder.setInsertionPointToEnd(elseClause);
|
||||
return parseIfInst();
|
||||
}
|
||||
|
||||
if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
|
||||
parseBlock(elseClause) ||
|
||||
parseToken(Token::r_brace, "expected '}' after instruction list"))
|
||||
return ParseFailure;
|
||||
return ParseSuccess;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Top-level entity parsing.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -91,7 +91,6 @@ TOK_KEYWORD(attributes)
|
|||
TOK_KEYWORD(bf16)
|
||||
TOK_KEYWORD(ceildiv)
|
||||
TOK_KEYWORD(dense)
|
||||
TOK_KEYWORD(else)
|
||||
TOK_KEYWORD(splat)
|
||||
TOK_KEYWORD(f16)
|
||||
TOK_KEYWORD(f32)
|
||||
|
@ -100,7 +99,6 @@ TOK_KEYWORD(false)
|
|||
TOK_KEYWORD(floordiv)
|
||||
TOK_KEYWORD(for)
|
||||
TOK_KEYWORD(func)
|
||||
TOK_KEYWORD(if)
|
||||
TOK_KEYWORD(index)
|
||||
TOK_KEYWORD(loc)
|
||||
TOK_KEYWORD(max)
|
||||
|
|
|
@ -188,16 +188,6 @@ void CSE::simplifyBlock(Block *bb) {
|
|||
simplifyBlock(cast<ForInst>(i).getBody());
|
||||
break;
|
||||
}
|
||||
case Instruction::Kind::If: {
|
||||
auto &ifInst = cast<IfInst>(i);
|
||||
if (auto *elseBlock = ifInst.getElse()) {
|
||||
ScopedMapTy::ScopeTy scope(knownValues);
|
||||
simplifyBlock(elseBlock);
|
||||
}
|
||||
ScopedMapTy::ScopeTy scope(knownValues);
|
||||
simplifyBlock(ifInst.getThen());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
|
@ -99,16 +100,16 @@ public:
|
|||
SmallVector<ForInst *, 4> forInsts;
|
||||
SmallVector<OperationInst *, 4> loadOpInsts;
|
||||
SmallVector<OperationInst *, 4> storeOpInsts;
|
||||
bool hasIfInst = false;
|
||||
bool hasNonForRegion = false;
|
||||
|
||||
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
|
||||
|
||||
void visitIfInst(IfInst *ifInst) { hasIfInst = true; }
|
||||
|
||||
void visitOperationInst(OperationInst *opInst) {
|
||||
if (opInst->isa<LoadOp>())
|
||||
if (opInst->getNumBlockLists() != 0)
|
||||
hasNonForRegion = true;
|
||||
else if (opInst->isa<LoadOp>())
|
||||
loadOpInsts.push_back(opInst);
|
||||
if (opInst->isa<StoreOp>())
|
||||
else if (opInst->isa<StoreOp>())
|
||||
storeOpInsts.push_back(opInst);
|
||||
}
|
||||
};
|
||||
|
@ -410,8 +411,8 @@ bool MemRefDependenceGraph::init(Function *f) {
|
|||
// all loads and store accesses it contains.
|
||||
LoopNestStateCollector collector;
|
||||
collector.walkForInst(forInst);
|
||||
// Return false if IfInsts are found (not currently supported).
|
||||
if (collector.hasIfInst)
|
||||
// Return false if a non 'for' region was found (not currently supported).
|
||||
if (collector.hasNonForRegion)
|
||||
return false;
|
||||
Node node(id++, &inst);
|
||||
for (auto *opInst : collector.loadOpInsts) {
|
||||
|
@ -434,19 +435,18 @@ bool MemRefDependenceGraph::init(Function *f) {
|
|||
auto *memref = opInst->cast<LoadOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
nodes.insert({node.id, node});
|
||||
}
|
||||
if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
|
||||
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
|
||||
// Create graph node for top-level store op.
|
||||
Node node(id++, &inst);
|
||||
node.stores.push_back(opInst);
|
||||
auto *memref = opInst->cast<StoreOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
nodes.insert({node.id, node});
|
||||
} else if (opInst->getNumBlockLists() != 0) {
|
||||
// Return false if another region is found (not currently supported).
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Return false if IfInsts are found (not currently supported).
|
||||
if (isa<IfInst>(&inst))
|
||||
return false;
|
||||
}
|
||||
|
||||
// Walk memref access lists and add graph edges between dependent nodes.
|
||||
|
|
|
@ -119,15 +119,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool walkIfInstPostOrder(IfInst *ifInst) {
|
||||
bool hasInnerLoops =
|
||||
walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end());
|
||||
if (ifInst->hasElse())
|
||||
hasInnerLoops |=
|
||||
walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end());
|
||||
return hasInnerLoops;
|
||||
}
|
||||
|
||||
bool walkOpInstPostOrder(OperationInst *opInst) {
|
||||
for (auto &blockList : opInst->getBlockLists())
|
||||
for (auto &block : blockList)
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
@ -246,7 +247,7 @@ public:
|
|||
PassResult runOnFunction(Function *function) override;
|
||||
|
||||
bool lowerForInst(ForInst *forInst);
|
||||
bool lowerIfInst(IfInst *ifInst);
|
||||
bool lowerAffineIf(AffineIfOp *ifOp);
|
||||
bool lowerAffineApply(AffineApplyOp *op);
|
||||
|
||||
static char passID;
|
||||
|
@ -409,7 +410,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
|
|||
// enabling easy nesting of "if" instructions and if-then-else-if chains.
|
||||
//
|
||||
// +--------------------------------+
|
||||
// | <code before the IfInst> |
|
||||
// | <code before the AffineIfOp> |
|
||||
// | %zero = constant 0 : index |
|
||||
// | %v = affine_apply #expr1(%ops) |
|
||||
// | %c = cmpi "sge" %v, %zero |
|
||||
|
@ -453,10 +454,11 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
|
|||
// v v
|
||||
// +--------------------------------+
|
||||
// | continue: |
|
||||
// | <code after the IfInst> |
|
||||
// | <code after the AffineIfOp> |
|
||||
// +--------------------------------+
|
||||
//
|
||||
bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
|
||||
bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) {
|
||||
auto *ifInst = ifOp->getInstruction();
|
||||
auto loc = ifInst->getLoc();
|
||||
|
||||
// Start by splitting the block containing the 'if' into two parts. The part
|
||||
|
@ -466,22 +468,38 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
|
|||
auto *continueBlock = condBlock->splitBlock(ifInst);
|
||||
|
||||
// Create a block for the 'then' code, inserting it between the cond and
|
||||
// continue blocks. Move the instructions over from the IfInst and add a
|
||||
// continue blocks. Move the instructions over from the AffineIfOp and add a
|
||||
// branch to the continuation point.
|
||||
Block *thenBlock = new Block();
|
||||
thenBlock->insertBefore(continueBlock);
|
||||
|
||||
auto *oldThen = ifInst->getThen();
|
||||
thenBlock->getInstructions().splice(thenBlock->begin(),
|
||||
oldThen->getInstructions(),
|
||||
oldThen->begin(), oldThen->end());
|
||||
// If the 'then' block is not empty, then splice the instructions.
|
||||
auto &oldThenBlocks = ifOp->getThenBlocks();
|
||||
if (!oldThenBlocks.empty()) {
|
||||
// We currently only handle one 'then' block.
|
||||
if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end())
|
||||
return true;
|
||||
|
||||
Block *oldThen = &oldThenBlocks.front();
|
||||
|
||||
thenBlock->getInstructions().splice(thenBlock->begin(),
|
||||
oldThen->getInstructions(),
|
||||
oldThen->begin(), oldThen->end());
|
||||
}
|
||||
|
||||
FuncBuilder builder(thenBlock);
|
||||
builder.create<BranchOp>(loc, continueBlock);
|
||||
|
||||
// Handle the 'else' block the same way, but we skip it if we have no else
|
||||
// code.
|
||||
Block *elseBlock = continueBlock;
|
||||
if (auto *oldElse = ifInst->getElse()) {
|
||||
auto &oldElseBlocks = ifOp->getElseBlocks();
|
||||
if (!oldElseBlocks.empty()) {
|
||||
// We currently only handle one 'else' block.
|
||||
if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end())
|
||||
return true;
|
||||
|
||||
auto *oldElse = &oldElseBlocks.front();
|
||||
elseBlock = new Block();
|
||||
elseBlock->insertBefore(continueBlock);
|
||||
|
||||
|
@ -493,7 +511,7 @@ bool LowerAffinePass::lowerIfInst(IfInst *ifInst) {
|
|||
}
|
||||
|
||||
// Ok, now we just have to handle the condition logic.
|
||||
auto integerSet = ifInst->getCondition().getIntegerSet();
|
||||
auto integerSet = ifOp->getIntegerSet();
|
||||
|
||||
// Implement short-circuit logic. For each affine expression in the 'if'
|
||||
// condition, convert it into an affine map and call `affine_apply` to obtain
|
||||
|
@ -593,29 +611,30 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) {
|
|||
PassResult LowerAffinePass::runOnFunction(Function *function) {
|
||||
SmallVector<Instruction *, 8> instsToRewrite;
|
||||
|
||||
// Collect all the If and For instructions as well as AffineApplyOps. We do
|
||||
// this as a prepass to avoid invalidating the walker with our rewrite.
|
||||
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
|
||||
// We do this as a prepass to avoid invalidating the walker with our rewrite.
|
||||
function->walkInsts([&](Instruction *inst) {
|
||||
if (isa<IfInst>(inst) || isa<ForInst>(inst))
|
||||
if (isa<ForInst>(inst))
|
||||
instsToRewrite.push_back(inst);
|
||||
auto op = dyn_cast<OperationInst>(inst);
|
||||
if (op && op->isa<AffineApplyOp>())
|
||||
if (op && (op->isa<AffineApplyOp>() || op->isa<AffineIfOp>()))
|
||||
instsToRewrite.push_back(inst);
|
||||
});
|
||||
|
||||
// Rewrite all of the ifs and fors. We walked the instructions in preorder,
|
||||
// so we know that we will rewrite them in the same order.
|
||||
for (auto *inst : instsToRewrite)
|
||||
if (auto *ifInst = dyn_cast<IfInst>(inst)) {
|
||||
if (lowerIfInst(ifInst))
|
||||
return failure();
|
||||
} else if (auto *forInst = dyn_cast<ForInst>(inst)) {
|
||||
if (auto *forInst = dyn_cast<ForInst>(inst)) {
|
||||
if (lowerForInst(forInst))
|
||||
return failure();
|
||||
} else {
|
||||
auto op = cast<OperationInst>(inst);
|
||||
if (lowerAffineApply(op->cast<AffineApplyOp>()))
|
||||
if (auto ifOp = op->dyn_cast<AffineIfOp>()) {
|
||||
if (lowerAffineIf(ifOp))
|
||||
return failure();
|
||||
} else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/Dominance.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
|
@ -559,9 +560,6 @@ static bool instantiateMaterialization(Instruction *inst,
|
|||
if (isa<ForInst>(inst))
|
||||
return inst->emitError("NYI path ForInst");
|
||||
|
||||
if (isa<IfInst>(inst))
|
||||
return inst->emitError("NYI path IfInst");
|
||||
|
||||
// Create a builder here for unroll-and-jam effects.
|
||||
FuncBuilder b(inst);
|
||||
auto *opInst = cast<OperationInst>(inst);
|
||||
|
@ -570,6 +568,9 @@ static bool instantiateMaterialization(Instruction *inst,
|
|||
if (opInst->isa<AffineApplyOp>()) {
|
||||
return false;
|
||||
}
|
||||
if (opInst->getNumBlockLists() != 0)
|
||||
return inst->emitError("NYI path Op with region");
|
||||
|
||||
if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
|
||||
auto *clone = instantiate(&b, write, state->hwVectorType,
|
||||
state->hwVectorInstance, state->substitutionsMap);
|
||||
|
|
|
@ -28,7 +28,6 @@
|
|||
#define DEBUG_TYPE "simplify-affine-structure"
|
||||
|
||||
using namespace mlir;
|
||||
using llvm::report_fatal_error;
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -42,9 +41,6 @@ struct SimplifyAffineStructures : public FunctionPass {
|
|||
|
||||
PassResult runOnFunction(Function *f) override;
|
||||
|
||||
void visitIfInst(IfInst *ifInst);
|
||||
void visitOperationInst(OperationInst *opInst);
|
||||
|
||||
static char passID;
|
||||
};
|
||||
|
||||
|
@ -66,28 +62,19 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
|
|||
return set;
|
||||
}
|
||||
|
||||
void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) {
|
||||
auto set = ifInst->getCondition().getIntegerSet();
|
||||
ifInst->setIntegerSet(simplifyIntegerSet(set));
|
||||
}
|
||||
|
||||
void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) {
|
||||
for (auto attr : opInst->getAttrs()) {
|
||||
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
|
||||
MutableAffineMap mMap(mapAttr.getValue());
|
||||
mMap.simplify();
|
||||
auto map = mMap.getAffineMap();
|
||||
opInst->setAttr(attr.first, AffineMapAttr::get(map));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
|
||||
f->walkInsts([&](Instruction *inst) {
|
||||
if (auto *opInst = dyn_cast<OperationInst>(inst))
|
||||
visitOperationInst(opInst);
|
||||
if (auto *ifInst = dyn_cast<IfInst>(inst))
|
||||
visitIfInst(ifInst);
|
||||
f->walkOps([&](OperationInst *opInst) {
|
||||
for (auto attr : opInst->getAttrs()) {
|
||||
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
|
||||
MutableAffineMap mMap(mapAttr.getValue());
|
||||
mMap.simplify();
|
||||
auto map = mMap.getAffineMap();
|
||||
opInst->setAttr(attr.first, AffineMapAttr::get(map));
|
||||
} else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>()) {
|
||||
auto simplified = simplifyIntegerSet(setAttr.getValue());
|
||||
opInst->setAttr(attr.first, IntegerSetAttr::get(simplified));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return success();
|
||||
|
|
|
@ -243,14 +243,6 @@ func @non_instruction() {
|
|||
|
||||
// -----
|
||||
|
||||
func @invalid_if_conditional1() {
|
||||
for %i = 1 to 10 {
|
||||
if () { // expected-error {{expected ':' or '['}}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_if_conditional2() {
|
||||
for %i = 1 to 10 {
|
||||
if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}}
|
||||
|
@ -664,7 +656,11 @@ func @invalid_if_operands2(%N : index) {
|
|||
func @invalid_if_operands3(%N : index) {
|
||||
for %i = 1 to 10 {
|
||||
if #set0(%i)[%i] {
|
||||
// expected-error@-1 {{value '%i' cannot be used as a symbol}}
|
||||
// expected-error@-1 {{operand cannot be used as a symbol}}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// expected-error@+1 {{expected '"' in string literal}}
|
||||
|
|
|
@ -16,9 +16,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
|
|||
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
|
||||
}
|
||||
|
||||
// CHECK: ) loc(fused<"myPass">["foo", "foo2"])
|
||||
if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) {
|
||||
}
|
||||
// CHECK: } loc(fused<"myPass">["foo", "foo2"])
|
||||
if #set0(%2) {
|
||||
} loc(fused<"myPass">["foo", "foo2"])
|
||||
|
||||
// CHECK: return %0 : i32 loc(unknown)
|
||||
return %1 : i32 loc(unknown)
|
||||
|
|
|
@ -287,13 +287,15 @@ func @ifinst(%N: index) {
|
|||
// CHECK: %c1_i32 = constant 1 : i32
|
||||
%y = "add"(%x, %i) : (i32, index) -> i32 // CHECK: %0 = "add"(%c1_i32, %i0) : (i32, index) -> i32
|
||||
%z = "mul"(%y, %y) : (i32, i32) -> i32 // CHECK: %1 = "mul"(%0, %0) : (i32, i32) -> i32
|
||||
} else if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK } else if (#set1(%i0)[%arg0]) {
|
||||
// CHECK: %c1 = constant 1 : index
|
||||
%u = constant 1 : index
|
||||
// CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1]
|
||||
%w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u]
|
||||
} else { // CHECK } else {
|
||||
%v = constant 3 : i32 // %c3_i32 = constant 3 : i32
|
||||
} else { // CHECK } else {
|
||||
if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK if (#set1(%i0)[%arg0]) {
|
||||
// CHECK: %c1 = constant 1 : index
|
||||
%u = constant 1 : index
|
||||
// CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1]
|
||||
%w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u]
|
||||
} else { // CHECK } else {
|
||||
%v = constant 3 : i32 // %c3_i32 = constant 3 : i32
|
||||
}
|
||||
} // CHECK }
|
||||
} // CHECK }
|
||||
return // CHECK return
|
||||
|
@ -751,11 +753,11 @@ func @type_alias() -> !i32_type_alias {
|
|||
func @verbose_if(%N: index) {
|
||||
%c = constant 200 : index
|
||||
|
||||
// CHECK: "if"(%c200, %arg0, %c200) {cond: #set0} : (index, index, index) -> () {
|
||||
"if"(%c, %N, %c) { cond: #set0 } : (index, index, index) -> () {
|
||||
// CHECK: if #set0(%c200)[%arg0, %c200] {
|
||||
"if"(%c, %N, %c) { condition: #set0 } : (index, index, index) -> () {
|
||||
// CHECK-NEXT: "add"
|
||||
%y = "add"(%c, %N) : (index, index) -> index
|
||||
// CHECK-NEXT: } {
|
||||
// CHECK-NEXT: } else {
|
||||
} { // The else block list.
|
||||
// CHECK-NEXT: "add"
|
||||
%z = "add"(%c, %c) : (index, index) -> index
|
||||
|
|
|
@ -21,10 +21,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
|
|||
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
|
||||
}
|
||||
|
||||
// CHECK: ) <"myPass">["foo", "foo2"]
|
||||
if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) {
|
||||
}
|
||||
// CHECK: } <"myPass">["foo", "foo2"]
|
||||
if #set0(%2) {
|
||||
} loc(fused<"myPass">["foo", "foo2"])
|
||||
|
||||
// CHECK: return %0 : i32 [unknown]
|
||||
return %1 : i32 loc(unknown)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -483,7 +483,7 @@ func @should_not_fuse_if_inst_at_top_level() {
|
|||
%c0 = constant 4 : index
|
||||
if #set0(%c0) {
|
||||
}
|
||||
// Top-level IfInst should prevent fusion.
|
||||
// Top-level IfOp should prevent fusion.
|
||||
// CHECK: for %i0 = 0 to 10 {
|
||||
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
@ -512,7 +512,7 @@ func @should_not_fuse_if_inst_in_loop_nest() {
|
|||
%v0 = load %m[%i1] : memref<10xf32>
|
||||
}
|
||||
|
||||
// IfInst in ForInst should prevent fusion.
|
||||
// IfOp in ForInst should prevent fusion.
|
||||
// CHECK: for %i0 = 0 to 10 {
|
||||
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
|
||||
// CHECK-NEXT: }
|
||||
|
|
|
@ -10,7 +10,7 @@ func @store_may_execute_before_load() {
|
|||
%cf7 = constant 7.0 : f32
|
||||
%c0 = constant 4 : index
|
||||
// There is a dependence from store 0 to load 1 at depth 1 because the
|
||||
// ancestor IfInst of the store, dominates the ancestor ForSmt of the load,
|
||||
// ancestor IfOp of the store, dominates the ancestor ForSmt of the load,
|
||||
// and thus the store "may" conditionally execute before the load.
|
||||
if #set0(%c0) {
|
||||
for %i0 = 0 to 10 {
|
||||
|
|
|
@ -13,10 +13,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) {
|
|||
for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) {
|
||||
}
|
||||
|
||||
// CHECK: if #set0(%c4) loc(unknown)
|
||||
// CHECK: } loc(unknown)
|
||||
%2 = constant 4 : index
|
||||
if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) {
|
||||
}
|
||||
if #set0(%2) {
|
||||
} loc(fused<"myPass">["foo", "foo2"])
|
||||
|
||||
// CHECK: return %0 : i32 loc(unknown)
|
||||
return %1 : i32 loc("bar")
|
||||
|
|
Loading…
Reference in New Issue