NFC. Move all of the remaining operations left in BuiltinOps to StandardOps. The only thing left in BuiltinOps are the core MLIR types. The standard types can't be moved because they are referenced within the IR directory, e.g. in things like Builder.

PiperOrigin-RevId: 236403665
This commit is contained in:
River Riddle 2019-03-01 16:58:00 -08:00 committed by jpienaar
parent 85d9b6c8f7
commit f37651c708
38 changed files with 771 additions and 869 deletions

View File

@ -9,7 +9,6 @@
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/MLIREmitter.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/EDSC/Types.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/ExecutionEngine/ExecutionEngine.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/BuiltinOps.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/IR/Module.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm/projects/google_mlir/include/mlir/Target/LLVMIR.h"

View File

@ -1,357 +0,0 @@
//===- BuiltinOps.h - Builtin MLIR Operations -------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines convenience types for working with builtin operations
// in the MLIR instruction set.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_BUILTINOPS_H
#define MLIR_IR_BUILTINOPS_H
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
class Builder;
class BuiltinDialect : public Dialect {
public:
BuiltinDialect(MLIRContext *context);
};
/// The "br" operation represents a branch instruction in a CFG function.
/// The operation takes variable number of operands and produces no results.
/// The operand number and types for each successor must match the
/// arguments of the block successor. For example:
///
/// bb2:
/// %2 = call @someFn()
/// br bb3(%2 : tensor<*xf32>)
/// bb3(%3: tensor<*xf32>):
///
class BranchOp : public Op<BranchOp, OpTrait::VariadicOperands,
OpTrait::ZeroResult, OpTrait::IsTerminator> {
public:
static StringRef getOperationName() { return "br"; }
static void build(Builder *builder, OperationState *result, Block *dest,
ArrayRef<Value *> operands = {});
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
/// Return the block this branch jumps to.
Block *getDest();
const Block *getDest() const {
return const_cast<BranchOp *>(this)->getDest();
}
void setDest(Block *block);
/// Erase the operand at 'index' from the operand list.
void eraseOperand(unsigned index);
private:
friend class Instruction;
explicit BranchOp(const Instruction *state) : Op(state) {}
};
/// The "cond_br" operation represents a conditional branch instruction in a
/// CFG function. The operation takes variable number of operands and produces
/// no results. The operand number and types for each successor must match the
// arguments of the block successor. For example:
///
/// bb0:
/// %0 = extract_element %arg0[] : tensor<i1>
/// cond_br %0, bb1, bb2
/// bb1:
/// ...
/// bb2:
/// ...
///
class CondBranchOp : public Op<CondBranchOp, OpTrait::AtLeastNOperands<1>::Impl,
OpTrait::ZeroResult, OpTrait::IsTerminator> {
// These are the indices into the dests list.
enum { trueIndex = 0, falseIndex = 1 };
/// The operands list of a conditional branch operation is layed out as
/// follows:
/// { condition, [true_operands], [false_operands] }
public:
static StringRef getOperationName() { return "cond_br"; }
static void build(Builder *builder, OperationState *result, Value *condition,
Block *trueDest, ArrayRef<Value *> trueOperands,
Block *falseDest, ArrayRef<Value *> falseOperands);
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
// The condition operand is the first operand in the list.
Value *getCondition() { return getOperand(0); }
const Value *getCondition() const { return getOperand(0); }
/// Return the destination if the condition is true.
Block *getTrueDest();
const Block *getTrueDest() const {
return const_cast<CondBranchOp *>(this)->getTrueDest();
}
/// Return the destination if the condition is false.
Block *getFalseDest();
const Block *getFalseDest() const {
return const_cast<CondBranchOp *>(this)->getFalseDest();
}
// Accessors for operands to the 'true' destination.
Value *getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx);
}
const Value *getTrueOperand(unsigned idx) const {
return const_cast<CondBranchOp *>(this)->getTrueOperand(idx);
}
void setTrueOperand(unsigned idx, Value *value) {
assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value);
}
operand_iterator true_operand_begin() {
return operand_begin() + getTrueDestOperandIndex();
}
operand_iterator true_operand_end() {
return true_operand_begin() + getNumTrueOperands();
}
llvm::iterator_range<operand_iterator> getTrueOperands() {
return {true_operand_begin(), true_operand_end()};
}
const_operand_iterator true_operand_begin() const {
return operand_begin() + getTrueDestOperandIndex();
}
const_operand_iterator true_operand_end() const {
return true_operand_begin() + getNumTrueOperands();
}
llvm::iterator_range<const_operand_iterator> getTrueOperands() const {
return {true_operand_begin(), true_operand_end()};
}
unsigned getNumTrueOperands() const;
/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index);
// Accessors for operands to the 'false' destination.
Value *getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx);
}
const Value *getFalseOperand(unsigned idx) const {
return const_cast<CondBranchOp *>(this)->getFalseOperand(idx);
}
void setFalseOperand(unsigned idx, Value *value) {
assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value);
}
operand_iterator false_operand_begin() { return true_operand_end(); }
operand_iterator false_operand_end() {
return false_operand_begin() + getNumFalseOperands();
}
llvm::iterator_range<operand_iterator> getFalseOperands() {
return {false_operand_begin(), false_operand_end()};
}
const_operand_iterator false_operand_begin() const {
return true_operand_end();
}
const_operand_iterator false_operand_end() const {
return false_operand_begin() + getNumFalseOperands();
}
llvm::iterator_range<const_operand_iterator> getFalseOperands() const {
return {false_operand_begin(), false_operand_end()};
}
unsigned getNumFalseOperands() const;
/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index);
private:
/// Get the index of the first true destination operand.
unsigned getTrueDestOperandIndex() const { return 1; }
/// Get the index of the first false destination operand.
unsigned getFalseDestOperandIndex() const {
return getTrueDestOperandIndex() + getNumTrueOperands();
}
friend class Instruction;
explicit CondBranchOp(const Instruction *state) : Op(state) {}
};
/// The "constant" operation requires a single attribute named "value".
/// It returns its value as an SSA value. For example:
///
/// %1 = "constant"(){value: 42} : i32
/// %2 = "constant"(){value: @foo} : (f32)->f32
///
class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
/// Builds a constant op with the specified attribute value and result type.
static void build(Builder *builder, OperationState *result, Type type,
Attribute value);
/// Builds a constant op with the specified attribute value and the
/// attribute's type.
static void build(Builder *builder, OperationState *result, Attribute value);
Attribute getValue() const { return getAttr("value"); }
static StringRef getOperationName() { return "constant"; }
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
protected:
friend class Instruction;
explicit ConstantOp(const Instruction *state) : Op(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///
/// %1 = "constant"(){value: 42.0} : bf16
///
class ConstantFloatOp : public ConstantOp {
public:
/// Builds a constant float op producing a float of the specified type.
static void build(Builder *builder, OperationState *result,
const APFloat &value, FloatType type);
APFloat getValue() const {
return getAttrOfType<FloatAttr>("value").getValue();
}
static bool isClassFor(const Instruction *op);
private:
friend class Instruction;
explicit ConstantFloatOp(const Instruction *state) : ConstantOp(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of IntegerType.
///
/// %1 = "constant"(){value: 42} : i32
///
class ConstantIntOp : public ConstantOp {
public:
/// Build a constant int op producing an integer of the specified width.
static void build(Builder *builder, OperationState *result, int64_t value,
unsigned width);
/// Build a constant int op producing an integer with the specified type,
/// which must be an integer type.
static void build(Builder *builder, OperationState *result, int64_t value,
Type type);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value").getInt();
}
static bool isClassFor(const Instruction *op);
private:
friend class Instruction;
explicit ConstantIntOp(const Instruction *state) : ConstantOp(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of Index type.
///
/// %1 = "constant"(){value: 99} : () -> index
///
class ConstantIndexOp : public ConstantOp {
public:
/// Build a constant int op producing an index.
static void build(Builder *builder, OperationState *result, int64_t value);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value").getInt();
}
static bool isClassFor(const Instruction *op);
private:
friend class Instruction;
explicit ConstantIndexOp(const Instruction *state) : ConstantOp(state) {}
};
/// The "return" operation represents a return instruction within a function.
/// The operation takes variable number of operands and produces no results.
/// The operand number and types must match the signature of the function
/// that contains the operation. For example:
///
/// mlfunc @foo() : (i32, f8) {
/// ...
/// return %0, %1 : i32, f8
///
class ReturnOp : public Op<ReturnOp, OpTrait::VariadicOperands,
OpTrait::ZeroResult, OpTrait::IsTerminator> {
public:
static StringRef getOperationName() { return "return"; }
static void build(Builder *builder, OperationState *result,
ArrayRef<Value *> results = {});
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
private:
friend class Instruction;
explicit ReturnOp(const Instruction *state) : Op(state) {}
};
/// Prints dimension and symbol list.
void printDimAndSymbolList(Instruction::const_operand_iterator begin,
Instruction::const_operand_iterator end,
unsigned numDims, OpAsmPrinter *p);
/// Parses dimension and symbol list and returns true if parsing failed.
bool parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims);
} // end namespace mlir
#endif

View File

@ -268,9 +268,6 @@ public:
/// take O(N) where N is the number of instructions within the parent block.
bool isBeforeInBlock(const Instruction *other) const;
/// Check if this instruction is a return instruction.
bool isReturn() const;
void print(raw_ostream &os) const;
void dump() const;

View File

@ -25,7 +25,7 @@
#define MLIR_MATCHERS_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include <type_traits>
@ -134,11 +134,6 @@ inline bool matchPattern(Value *value, const Pattern &pattern) {
return false;
}
/// Matches a ConstantIndexOp.
inline detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
return detail::op_matcher<ConstantIndexOp>();
}
/// Matches a constant holding a scalar/vector/tensor integer (splat) and
/// writes the integer value to bind_value.
inline detail::constant_int_op_binder

View File

@ -83,6 +83,43 @@ private:
explicit AllocOp(const Instruction *state) : Op(state) {}
};
/// The "br" operation represents a branch instruction in a function.
/// The operation takes variable number of operands and produces no results.
/// The operand number and types for each successor must match the
/// arguments of the block successor. For example:
///
/// ^bb2:
/// %2 = call @someFn()
/// br ^bb3(%2 : tensor<*xf32>)
/// ^bb3(%3: tensor<*xf32>):
///
class BranchOp : public Op<BranchOp, OpTrait::VariadicOperands,
OpTrait::ZeroResult, OpTrait::IsTerminator> {
public:
static StringRef getOperationName() { return "br"; }
static void build(Builder *builder, OperationState *result, Block *dest,
ArrayRef<Value *> operands = {});
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
/// Return the block this branch jumps to.
Block *getDest();
const Block *getDest() const {
return const_cast<BranchOp *>(this)->getDest();
}
void setDest(Block *block);
/// Erase the operand at 'index' from the operand list.
void eraseOperand(unsigned index);
private:
friend class Instruction;
explicit BranchOp(const Instruction *state) : Op(state) {}
};
/// The "call" operation represents a direct call to a function. The operands
/// and result types of the call must match the specified function type. The
/// callee is encoded as a function attribute named "callee".
@ -237,6 +274,248 @@ private:
explicit CmpIOp(const Instruction *state) : Op(state) {}
};
/// The "cond_br" operation represents a conditional branch instruction in a
/// function. The operation takes variable number of operands and produces
/// no results. The operand number and types for each successor must match the
// arguments of the block successor. For example:
///
/// ^bb0:
/// %0 = extract_element %arg0[] : tensor<i1>
/// cond_br %0, ^bb1, ^bb2
/// ^bb1:
/// ...
/// ^bb2:
/// ...
///
class CondBranchOp : public Op<CondBranchOp, OpTrait::AtLeastNOperands<1>::Impl,
OpTrait::ZeroResult, OpTrait::IsTerminator> {
// These are the indices into the dests list.
enum { trueIndex = 0, falseIndex = 1 };
/// The operands list of a conditional branch operation is layed out as
/// follows:
/// { condition, [true_operands], [false_operands] }
public:
static StringRef getOperationName() { return "cond_br"; }
static void build(Builder *builder, OperationState *result, Value *condition,
Block *trueDest, ArrayRef<Value *> trueOperands,
Block *falseDest, ArrayRef<Value *> falseOperands);
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
// The condition operand is the first operand in the list.
Value *getCondition() { return getOperand(0); }
const Value *getCondition() const { return getOperand(0); }
/// Return the destination if the condition is true.
Block *getTrueDest();
const Block *getTrueDest() const {
return const_cast<CondBranchOp *>(this)->getTrueDest();
}
/// Return the destination if the condition is false.
Block *getFalseDest();
const Block *getFalseDest() const {
return const_cast<CondBranchOp *>(this)->getFalseDest();
}
// Accessors for operands to the 'true' destination.
Value *getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx);
}
const Value *getTrueOperand(unsigned idx) const {
return const_cast<CondBranchOp *>(this)->getTrueOperand(idx);
}
void setTrueOperand(unsigned idx, Value *value) {
assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value);
}
operand_iterator true_operand_begin() {
return operand_begin() + getTrueDestOperandIndex();
}
operand_iterator true_operand_end() {
return true_operand_begin() + getNumTrueOperands();
}
llvm::iterator_range<operand_iterator> getTrueOperands() {
return {true_operand_begin(), true_operand_end()};
}
const_operand_iterator true_operand_begin() const {
return operand_begin() + getTrueDestOperandIndex();
}
const_operand_iterator true_operand_end() const {
return true_operand_begin() + getNumTrueOperands();
}
llvm::iterator_range<const_operand_iterator> getTrueOperands() const {
return {true_operand_begin(), true_operand_end()};
}
unsigned getNumTrueOperands() const;
/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index);
// Accessors for operands to the 'false' destination.
Value *getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx);
}
const Value *getFalseOperand(unsigned idx) const {
return const_cast<CondBranchOp *>(this)->getFalseOperand(idx);
}
void setFalseOperand(unsigned idx, Value *value) {
assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value);
}
operand_iterator false_operand_begin() { return true_operand_end(); }
operand_iterator false_operand_end() {
return false_operand_begin() + getNumFalseOperands();
}
llvm::iterator_range<operand_iterator> getFalseOperands() {
return {false_operand_begin(), false_operand_end()};
}
const_operand_iterator false_operand_begin() const {
return true_operand_end();
}
const_operand_iterator false_operand_end() const {
return false_operand_begin() + getNumFalseOperands();
}
llvm::iterator_range<const_operand_iterator> getFalseOperands() const {
return {false_operand_begin(), false_operand_end()};
}
unsigned getNumFalseOperands() const;
/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index);
private:
/// Get the index of the first true destination operand.
unsigned getTrueDestOperandIndex() const { return 1; }
/// Get the index of the first false destination operand.
unsigned getFalseDestOperandIndex() const {
return getTrueDestOperandIndex() + getNumTrueOperands();
}
friend class Instruction;
explicit CondBranchOp(const Instruction *state) : Op(state) {}
};
/// The "constant" operation requires a single attribute named "value".
/// It returns its value as an SSA value. For example:
///
/// %1 = "constant"(){value: 42} : i32
/// %2 = "constant"(){value: @foo} : (f32)->f32
///
class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
/// Builds a constant op with the specified attribute value and result type.
static void build(Builder *builder, OperationState *result, Type type,
Attribute value);
/// Builds a constant op with the specified attribute value and the
/// attribute's type.
static void build(Builder *builder, OperationState *result, Attribute value);
Attribute getValue() const { return getAttr("value"); }
static StringRef getOperationName() { return "constant"; }
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
protected:
friend class Instruction;
explicit ConstantOp(const Instruction *state) : Op(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///
/// %1 = "constant"(){value: 42.0} : bf16
///
class ConstantFloatOp : public ConstantOp {
public:
/// Builds a constant float op producing a float of the specified type.
static void build(Builder *builder, OperationState *result,
const APFloat &value, FloatType type);
APFloat getValue() const {
return getAttrOfType<FloatAttr>("value").getValue();
}
static bool isClassFor(const Instruction *op);
private:
friend class Instruction;
explicit ConstantFloatOp(const Instruction *state) : ConstantOp(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of IntegerType.
///
/// %1 = "constant"(){value: 42} : i32
///
class ConstantIntOp : public ConstantOp {
public:
/// Build a constant int op producing an integer of the specified width.
static void build(Builder *builder, OperationState *result, int64_t value,
unsigned width);
/// Build a constant int op producing an integer with the specified type,
/// which must be an integer type.
static void build(Builder *builder, OperationState *result, int64_t value,
Type type);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value").getInt();
}
static bool isClassFor(const Instruction *op);
private:
friend class Instruction;
explicit ConstantIntOp(const Instruction *state) : ConstantOp(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of Index type.
///
/// %1 = "constant"(){value: 99} : () -> index
///
class ConstantIndexOp : public ConstantOp {
public:
/// Build a constant int op producing an index.
static void build(Builder *builder, OperationState *result, int64_t value);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value").getInt();
}
static bool isClassFor(const Instruction *op);
private:
friend class Instruction;
explicit ConstantIndexOp(const Instruction *state) : ConstantOp(state) {}
};
/// The "dealloc" operation frees the region of memory referenced by a memref
/// which was originally created by the "alloc" operation.
/// The "dealloc" operation should not be called on memrefs which alias an
@ -636,6 +915,33 @@ private:
explicit MemRefCastOp(const Instruction *state) : CastOp(state) {}
};
/// The "return" operation represents a return instruction within a function.
/// The operation takes variable number of operands and produces no results.
/// The operand number and types must match the signature of the function
/// that contains the operation. For example:
///
/// mlfunc @foo() : (i32, f8) {
/// ...
/// return %0, %1 : i32, f8
///
class ReturnOp : public Op<ReturnOp, OpTrait::VariadicOperands,
OpTrait::ZeroResult, OpTrait::IsTerminator> {
public:
static StringRef getOperationName() { return "return"; }
static void build(Builder *builder, OperationState *result,
ArrayRef<Value *> results = {});
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
private:
friend class Instruction;
explicit ReturnOp(const Instruction *state) : Op(state) {}
};
/// The "select" operation chooses one value based on a binary condition
/// supplied as its first operand. If the value of the first operand is 1, the
/// second operand is chosen, otherwise the third operand is chosen. The second
@ -749,6 +1055,16 @@ private:
explicit TensorCastOp(const Instruction *state) : CastOp(state) {}
};
/// Prints dimension and symbol list.
void printDimAndSymbolList(Instruction::const_operand_iterator begin,
Instruction::const_operand_iterator end,
unsigned numDims, OpAsmPrinter *p);
/// Parses dimension and symbol list and returns true if parsing failed.
bool parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims);
} // end namespace mlir
#endif // MLIR_STANDARDOPS_OPS_H

View File

@ -26,7 +26,7 @@
#define MLIR_TRANSFORMS_UTILS_H
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"

View File

@ -18,7 +18,6 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"

View File

@ -26,7 +26,6 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/StandardOps/Ops.h"

View File

@ -23,9 +23,9 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallPtrSet.h"

View File

@ -27,7 +27,6 @@
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instruction.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"

View File

@ -25,7 +25,6 @@
#include "mlir/Analysis/Passes.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/Support/Debug.h"

View File

@ -24,7 +24,6 @@
#include "mlir/Analysis/Passes.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/Support/Debug.h"

View File

@ -22,7 +22,6 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instruction.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"

View File

@ -26,7 +26,6 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

View File

@ -19,7 +19,6 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/StandardOps/Ops.h"

View File

@ -19,7 +19,6 @@
#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"

View File

@ -26,7 +26,6 @@
#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"

View File

@ -23,7 +23,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h"

View File

@ -23,11 +23,12 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
@ -1205,22 +1206,25 @@ void FunctionPrinter::numberValueID(const Value *value) {
// Give constant integers special names.
if (auto *op = value->getDefiningInst()) {
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
Attribute cst;
if (m_Constant(&cst).match(const_cast<Instruction *>(op))) {
Type type = op->getResult(0)->getType();
if (auto intCst = cst.dyn_cast<IntegerAttr>()) {
if (type.isIndex()) {
specialName << 'c' << intCst;
} else if (type.cast<IntegerType>().isInteger(1)) {
// i1 constants get special names.
if (intOp->getType().isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
specialName << (intCst.getInt() ? "true" : "false");
} else {
specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
specialName << 'c' << intCst << '_' << type;
}
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue();
} else if (auto constant = op->dyn_cast<ConstantOp>()) {
if (constant->getValue().isa<FunctionAttr>())
} else if (cst.isa<FunctionAttr>()) {
specialName << 'f';
else
} else {
specialName << "cst";
}
}
}
if (specialNameBuffer.empty()) {
switch (value->getKind()) {

View File

@ -1,454 +0,0 @@
//===- BuiltinOps.cpp - Builtin MLIR Operations -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
BuiltinDialect::BuiltinDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
addOperations<BranchOp, CondBranchOp, ConstantOp, ReturnOp>();
addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType,
VectorType, RankedTensorType, UnrankedTensorType, MemRefType>();
}
void mlir::printDimAndSymbolList(Instruction::const_operand_iterator begin,
Instruction::const_operand_iterator end,
unsigned numDims, OpAsmPrinter *p) {
*p << '(';
p->printOperands(begin, begin + numDims);
*p << ')';
if (begin + numDims != end) {
*p << '[';
p->printOperands(begin + numDims, end);
*p << ']';
}
}
// Parses dimension and symbol list, and sets 'numDims' to the number of
// dimension operands parsed.
// Returns 'false' on success and 'true' on error.
bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims) {
SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
return true;
// Store number of dimensions for validation by caller.
numDims = opInfos.size();
// Parse the optional symbol operands.
auto affineIntTy = parser->getBuilder().getIndexType();
if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
return true;
return false;
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
void BranchOp::build(Builder *builder, OperationState *result, Block *dest,
ArrayRef<Value *> operands) {
result->addSuccessor(dest, operands);
}
bool BranchOp::parse(OpAsmParser *parser, OperationState *result) {
Block *dest;
SmallVector<Value *, 4> destOperands;
if (parser->parseSuccessorAndUseList(dest, destOperands))
return true;
result->addSuccessor(dest, destOperands);
return false;
}
void BranchOp::print(OpAsmPrinter *p) const {
*p << "br ";
p->printSuccessorAndUseList(getInstruction(), 0);
}
Block *BranchOp::getDest() { return getInstruction()->getSuccessor(0); }
void BranchOp::setDest(Block *block) {
return getInstruction()->setSuccessor(block, 0);
}
void BranchOp::eraseOperand(unsigned index) {
getInstruction()->eraseSuccessorOperand(0, index);
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
namespace {
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
///
struct SimplifyConstCondBranchPred : public RewritePattern {
SimplifyConstCondBranchPred(MLIRContext *context)
: RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
PatternMatchResult match(Instruction *op) const override {
auto condbr = op->cast<CondBranchOp>();
if (matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
return matchSuccess();
return matchFailure();
}
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
auto condbr = op->cast<CondBranchOp>();
Block *foldedDest;
SmallVector<Value *, 4> branchArgs;
// If the condition is known to evaluate to false we fold to a branch to the
// false destination. Otherwise, we fold to a branch to the true
// destination.
if (matchPattern(condbr->getCondition(), m_Zero())) {
foldedDest = condbr->getFalseDest();
branchArgs.assign(condbr->false_operand_begin(),
condbr->false_operand_end());
} else {
foldedDest = condbr->getTrueDest();
branchArgs.assign(condbr->true_operand_begin(),
condbr->true_operand_end());
}
rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
}
};
} // end anonymous namespace.
void CondBranchOp::build(Builder *builder, OperationState *result,
Value *condition, Block *trueDest,
ArrayRef<Value *> trueOperands, Block *falseDest,
ArrayRef<Value *> falseOperands) {
result->addOperands(condition);
result->addSuccessor(trueDest, trueOperands);
result->addSuccessor(falseDest, falseOperands);
}
bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<Value *, 4> destOperands;
Block *dest;
OpAsmParser::OperandType condInfo;
// Parse the condition.
Type int1Ty = parser->getBuilder().getI1Type();
if (parser->parseOperand(condInfo) || parser->parseComma() ||
parser->resolveOperand(condInfo, int1Ty, result->operands)) {
return parser->emitError(parser->getNameLoc(),
"expected condition type was boolean (i1)");
}
// Parse the true successor.
if (parser->parseSuccessorAndUseList(dest, destOperands))
return true;
result->addSuccessor(dest, destOperands);
// Parse the false successor.
destOperands.clear();
if (parser->parseComma() ||
parser->parseSuccessorAndUseList(dest, destOperands))
return true;
result->addSuccessor(dest, destOperands);
// Return false on success.
return false;
}
void CondBranchOp::print(OpAsmPrinter *p) const {
*p << "cond_br ";
p->printOperand(getCondition());
*p << ", ";
p->printSuccessorAndUseList(getInstruction(), trueIndex);
*p << ", ";
p->printSuccessorAndUseList(getInstruction(), falseIndex);
}
bool CondBranchOp::verify() const {
if (!getCondition()->getType().isInteger(1))
return emitOpError("expected condition type was boolean (i1)");
return false;
}
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(std::make_unique<SimplifyConstCondBranchPred>(context));
}
Block *CondBranchOp::getTrueDest() {
return getInstruction()->getSuccessor(trueIndex);
}
Block *CondBranchOp::getFalseDest() {
return getInstruction()->getSuccessor(falseIndex);
}
unsigned CondBranchOp::getNumTrueOperands() const {
return getInstruction()->getNumSuccessorOperands(trueIndex);
}
void CondBranchOp::eraseTrueOperand(unsigned index) {
getInstruction()->eraseSuccessorOperand(trueIndex, index);
}
unsigned CondBranchOp::getNumFalseOperands() const {
return getInstruction()->getNumSuccessorOperands(falseIndex);
}
void CondBranchOp::eraseFalseOperand(unsigned index) {
getInstruction()->eraseSuccessorOperand(falseIndex, index);
}
//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//
/// Builds a constant op with the specified attribute value and result type.
void ConstantOp::build(Builder *builder, OperationState *result, Type type,
Attribute value) {
result->addAttribute("value", value);
result->types.push_back(type);
}
// Extracts and returns a type of an attribute if it has one. Returns a null
// type otherwise. Currently, NumericAttrs and FunctionAttrs have types.
static Type getAttributeType(Attribute attr) {
assert(attr && "expected non-null attribute");
if (auto numericAttr = attr.dyn_cast<NumericAttr>())
return numericAttr.getType();
if (auto functionAttr = attr.dyn_cast<FunctionAttr>())
return functionAttr.getType();
return {};
}
/// Builds a constant with the specified attribute value and type extracted
/// from the attribute. The attribute must have a type.
void ConstantOp::build(Builder *builder, OperationState *result,
Attribute value) {
Type t = getAttributeType(value);
assert(t && "expected an attribute with a type");
return build(builder, result, t, value);
}
void ConstantOp::print(OpAsmPrinter *p) const {
*p << "constant ";
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"value"});
if (getAttrs().size() > 1)
*p << ' ';
*p << getValue();
if (!getValue().isa<FunctionAttr>())
*p << " : " << getType();
}
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute valueAttr;
Type type;
if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseAttribute(valueAttr, "value", result->attributes))
return true;
// 'constant' taking a function reference doesn't get a redundant type
// specifier. The attribute itself carries it.
if (auto fnAttr = valueAttr.dyn_cast<FunctionAttr>())
return parser->addTypeToList(fnAttr.getValue()->getType(), result->types);
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
type = intAttr.getType();
} else if (auto fpAttr = valueAttr.dyn_cast<FloatAttr>()) {
type = fpAttr.getType();
} else if (parser->parseColonType(type)) {
return true;
}
return parser->addTypeToList(type, result->types);
}
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
bool ConstantOp::verify() const {
auto value = getValue();
if (!value)
return emitOpError("requires a 'value' attribute");
auto type = this->getType();
if (type.isa<IntegerType>() || type.isIndex()) {
auto intAttr = value.dyn_cast<IntegerAttr>();
if (!intAttr)
return emitOpError(
"requires 'value' to be an integer for an integer result type");
// If the type has a known bitwidth we verify that the value can be
// represented with the given bitwidth.
if (!type.isIndex()) {
auto bitwidth = type.cast<IntegerType>().getWidth();
auto intVal = intAttr.getValue();
if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
return emitOpError("requires 'value' to be an integer within the range "
"of the integer result type");
}
return false;
}
if (type.isa<FloatType>()) {
if (!value.isa<FloatAttr>())
return emitOpError("requires 'value' to be a floating point constant");
return false;
}
if (type.isa<VectorOrTensorType>()) {
if (!value.isa<ElementsAttr>())
return emitOpError("requires 'value' to be a vector/tensor constant");
return false;
}
if (type.isa<FunctionType>()) {
if (!value.isa<FunctionAttr>())
return emitOpError("requires 'value' to be a function reference");
return false;
}
auto attrType = getAttributeType(value);
if (!attrType)
return emitOpError("requires 'value' attribute to have a type");
if (attrType != type)
return emitOpError("requires the type of the 'value' attribute to match "
"that of the operation result");
return emitOpError(
"requires a result type that aligns with the 'value' attribute");
}
Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.empty() && "constant has no operands");
return getValue();
}
void ConstantFloatOp::build(Builder *builder, OperationState *result,
const APFloat &value, FloatType type) {
ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value));
}
bool ConstantFloatOp::isClassFor(const Instruction *op) {
return ConstantOp::isClassFor(op) &&
op->getResult(0)->getType().isa<FloatType>();
}
/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Instruction *op) {
return ConstantOp::isClassFor(op) &&
op->getResult(0)->getType().isa<IntegerType>();
}
void ConstantIntOp::build(Builder *builder, OperationState *result,
int64_t value, unsigned width) {
Type type = builder->getIntegerType(width);
ConstantOp::build(builder, result, type,
builder->getIntegerAttr(type, value));
}
/// Build a constant int op producing an integer with the specified type,
/// which must be an integer type.
void ConstantIntOp::build(Builder *builder, OperationState *result,
int64_t value, Type type) {
assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
ConstantOp::build(builder, result, type,
builder->getIntegerAttr(type, value));
}
/// ConstantIndexOp only matches values whose result type is Index.
bool ConstantIndexOp::isClassFor(const Instruction *op) {
return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
}
void ConstantIndexOp::build(Builder *builder, OperationState *result,
int64_t value) {
Type type = builder->getIndexType();
ConstantOp::build(builder, result, type,
builder->getIntegerAttr(type, value));
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
void ReturnOp::build(Builder *builder, OperationState *result,
ArrayRef<Value *> results) {
result->addOperands(results);
}
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types;
llvm::SMLoc loc;
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
parser->resolveOperands(opInfo, types, loc, result->operands);
}
void ReturnOp::print(OpAsmPrinter *p) const {
*p << "return";
if (getNumOperands() > 0) {
*p << ' ';
p->printOperands(operand_begin(), operand_end());
*p << " : ";
interleave(
operand_begin(), operand_end(),
[&](const Value *e) { p->printType(e->getType()); },
[&]() { *p << ", "; });
}
}
bool ReturnOp::verify() const {
auto *function = getInstruction()->getFunction();
// The operand number and types must match the function signature.
const auto &results = function->getType().getResults();
if (getNumOperands() != results.size())
return emitOpError("has " + Twine(getNumOperands()) +
" operands, but enclosing function returns " +
Twine(results.size()));
for (unsigned i = 0, e = results.size(); i != e; ++i)
if (getOperand(i)->getType() != results[i])
return emitError("type of return operand " + Twine(i) +
" doesn't match function result type");
return false;
}

View File

@ -19,7 +19,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
@ -481,8 +481,6 @@ bool Instruction::use_empty() const {
return true;
}
bool Instruction::isReturn() const { return isa<ReturnOp>(); }
void Instruction::setSuccessor(Block *block, unsigned index) {
assert(index < getNumSuccessors());
getBlockOperands()[index].set(block);

View File

@ -25,7 +25,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
@ -45,6 +45,15 @@ using namespace mlir::detail;
using namespace llvm;
namespace {
/// A builtin dialect to define types/etc that are necessary for the
/// validity of the IR.
struct BuiltinDialect : public Dialect {
BuiltinDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) {
addTypes<FunctionType, UnknownType, FloatType, IndexType, IntegerType,
VectorType, RankedTensorType, UnrankedTensorType, MemRefType>();
}
};
struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
// Affine maps are uniqued based on their dim/symbol counts and affine
// expressions.

View File

@ -21,7 +21,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"

View File

@ -25,7 +25,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"

View File

@ -19,7 +19,6 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
@ -37,14 +36,57 @@ using namespace mlir;
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
addOperations<AllocOp, CallOp, CallIndirectOp, CmpIOp, DeallocOp, DimOp,
DmaStartOp, DmaWaitOp, ExtractElementOp, LoadOp, MemRefCastOp,
SelectOp, StoreOp, TensorCastOp,
addOperations<AllocOp, BranchOp, CallOp, CallIndirectOp, CmpIOp, CondBranchOp,
ConstantOp, DeallocOp, DimOp, DmaStartOp, DmaWaitOp,
ExtractElementOp, LoadOp, MemRefCastOp, ReturnOp, SelectOp,
StoreOp, TensorCastOp,
#define GET_OP_LIST
#include "mlir/StandardOps/Ops.inc"
>();
}
void mlir::printDimAndSymbolList(Instruction::const_operand_iterator begin,
Instruction::const_operand_iterator end,
unsigned numDims, OpAsmPrinter *p) {
*p << '(';
p->printOperands(begin, begin + numDims);
*p << ')';
if (begin + numDims != end) {
*p << '[';
p->printOperands(begin + numDims, end);
*p << ']';
}
}
// Parses dimension and symbol list, and sets 'numDims' to the number of
// dimension operands parsed.
// Returns 'false' on success and 'true' on error.
bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims) {
SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
return true;
// Store number of dimensions for validation by caller.
numDims = opInfos.size();
// Parse the optional symbol operands.
auto affineIntTy = parser->getBuilder().getIndexType();
if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
return true;
return false;
}
/// Matches a ConstantIndexOp.
/// TODO: This should probably just be a general matcher that uses m_Constant
/// and checks the operation for an index type.
static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
return detail::op_matcher<ConstantIndexOp>();
}
//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//
@ -310,6 +352,39 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.push_back(std::make_unique<SimplifyDeadAlloc>(context));
}
//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
void BranchOp::build(Builder *builder, OperationState *result, Block *dest,
ArrayRef<Value *> operands) {
result->addSuccessor(dest, operands);
}
bool BranchOp::parse(OpAsmParser *parser, OperationState *result) {
Block *dest;
SmallVector<Value *, 4> destOperands;
if (parser->parseSuccessorAndUseList(dest, destOperands))
return true;
result->addSuccessor(dest, destOperands);
return false;
}
void BranchOp::print(OpAsmPrinter *p) const {
*p << "br ";
p->printSuccessorAndUseList(getInstruction(), 0);
}
Block *BranchOp::getDest() { return getInstruction()->getSuccessor(0); }
void BranchOp::setDest(Block *block) {
return getInstruction()->setSuccessor(block, 0);
}
void BranchOp::eraseOperand(unsigned index) {
getInstruction()->eraseSuccessorOperand(0, index);
}
//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
@ -692,6 +767,300 @@ Attribute CmpIOp::constantFold(ArrayRef<Attribute> operands,
return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val));
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
namespace {
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
///
struct SimplifyConstCondBranchPred : public RewritePattern {
SimplifyConstCondBranchPred(MLIRContext *context)
: RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
PatternMatchResult match(Instruction *op) const override {
auto condbr = op->cast<CondBranchOp>();
if (matchPattern(condbr->getCondition(), m_Op<ConstantOp>()))
return matchSuccess();
return matchFailure();
}
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
auto condbr = op->cast<CondBranchOp>();
Block *foldedDest;
SmallVector<Value *, 4> branchArgs;
// If the condition is known to evaluate to false we fold to a branch to the
// false destination. Otherwise, we fold to a branch to the true
// destination.
if (matchPattern(condbr->getCondition(), m_Zero())) {
foldedDest = condbr->getFalseDest();
branchArgs.assign(condbr->false_operand_begin(),
condbr->false_operand_end());
} else {
foldedDest = condbr->getTrueDest();
branchArgs.assign(condbr->true_operand_begin(),
condbr->true_operand_end());
}
rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
}
};
} // end anonymous namespace.
void CondBranchOp::build(Builder *builder, OperationState *result,
Value *condition, Block *trueDest,
ArrayRef<Value *> trueOperands, Block *falseDest,
ArrayRef<Value *> falseOperands) {
result->addOperands(condition);
result->addSuccessor(trueDest, trueOperands);
result->addSuccessor(falseDest, falseOperands);
}
bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<Value *, 4> destOperands;
Block *dest;
OpAsmParser::OperandType condInfo;
// Parse the condition.
Type int1Ty = parser->getBuilder().getI1Type();
if (parser->parseOperand(condInfo) || parser->parseComma() ||
parser->resolveOperand(condInfo, int1Ty, result->operands)) {
return parser->emitError(parser->getNameLoc(),
"expected condition type was boolean (i1)");
}
// Parse the true successor.
if (parser->parseSuccessorAndUseList(dest, destOperands))
return true;
result->addSuccessor(dest, destOperands);
// Parse the false successor.
destOperands.clear();
if (parser->parseComma() ||
parser->parseSuccessorAndUseList(dest, destOperands))
return true;
result->addSuccessor(dest, destOperands);
// Return false on success.
return false;
}
void CondBranchOp::print(OpAsmPrinter *p) const {
*p << "cond_br ";
p->printOperand(getCondition());
*p << ", ";
p->printSuccessorAndUseList(getInstruction(), trueIndex);
*p << ", ";
p->printSuccessorAndUseList(getInstruction(), falseIndex);
}
bool CondBranchOp::verify() const {
if (!getCondition()->getType().isInteger(1))
return emitOpError("expected condition type was boolean (i1)");
return false;
}
void CondBranchOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.push_back(std::make_unique<SimplifyConstCondBranchPred>(context));
}
Block *CondBranchOp::getTrueDest() {
return getInstruction()->getSuccessor(trueIndex);
}
Block *CondBranchOp::getFalseDest() {
return getInstruction()->getSuccessor(falseIndex);
}
unsigned CondBranchOp::getNumTrueOperands() const {
return getInstruction()->getNumSuccessorOperands(trueIndex);
}
void CondBranchOp::eraseTrueOperand(unsigned index) {
getInstruction()->eraseSuccessorOperand(trueIndex, index);
}
unsigned CondBranchOp::getNumFalseOperands() const {
return getInstruction()->getNumSuccessorOperands(falseIndex);
}
void CondBranchOp::eraseFalseOperand(unsigned index) {
getInstruction()->eraseSuccessorOperand(falseIndex, index);
}
//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//
/// Builds a constant op with the specified attribute value and result type.
void ConstantOp::build(Builder *builder, OperationState *result, Type type,
Attribute value) {
result->addAttribute("value", value);
result->types.push_back(type);
}
// Extracts and returns a type of an attribute if it has one. Returns a null
// type otherwise. Currently, NumericAttrs and FunctionAttrs have types.
static Type getAttributeType(Attribute attr) {
assert(attr && "expected non-null attribute");
if (auto numericAttr = attr.dyn_cast<NumericAttr>())
return numericAttr.getType();
if (auto functionAttr = attr.dyn_cast<FunctionAttr>())
return functionAttr.getType();
return {};
}
/// Builds a constant with the specified attribute value and type extracted
/// from the attribute. The attribute must have a type.
void ConstantOp::build(Builder *builder, OperationState *result,
Attribute value) {
Type t = getAttributeType(value);
assert(t && "expected an attribute with a type");
return build(builder, result, t, value);
}
void ConstantOp::print(OpAsmPrinter *p) const {
*p << "constant ";
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"value"});
if (getAttrs().size() > 1)
*p << ' ';
*p << getValue();
if (!getValue().isa<FunctionAttr>())
*p << " : " << getType();
}
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute valueAttr;
Type type;
if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseAttribute(valueAttr, "value", result->attributes))
return true;
// 'constant' taking a function reference doesn't get a redundant type
// specifier. The attribute itself carries it.
if (auto fnAttr = valueAttr.dyn_cast<FunctionAttr>())
return parser->addTypeToList(fnAttr.getValue()->getType(), result->types);
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
type = intAttr.getType();
} else if (auto fpAttr = valueAttr.dyn_cast<FloatAttr>()) {
type = fpAttr.getType();
} else if (parser->parseColonType(type)) {
return true;
}
return parser->addTypeToList(type, result->types);
}
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
bool ConstantOp::verify() const {
auto value = getValue();
if (!value)
return emitOpError("requires a 'value' attribute");
auto type = this->getType();
if (type.isa<IntegerType>() || type.isIndex()) {
auto intAttr = value.dyn_cast<IntegerAttr>();
if (!intAttr)
return emitOpError(
"requires 'value' to be an integer for an integer result type");
// If the type has a known bitwidth we verify that the value can be
// represented with the given bitwidth.
if (!type.isIndex()) {
auto bitwidth = type.cast<IntegerType>().getWidth();
auto intVal = intAttr.getValue();
if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
return emitOpError("requires 'value' to be an integer within the range "
"of the integer result type");
}
return false;
}
if (type.isa<FloatType>()) {
if (!value.isa<FloatAttr>())
return emitOpError("requires 'value' to be a floating point constant");
return false;
}
if (type.isa<VectorOrTensorType>()) {
if (!value.isa<ElementsAttr>())
return emitOpError("requires 'value' to be a vector/tensor constant");
return false;
}
if (type.isa<FunctionType>()) {
if (!value.isa<FunctionAttr>())
return emitOpError("requires 'value' to be a function reference");
return false;
}
auto attrType = getAttributeType(value);
if (!attrType)
return emitOpError("requires 'value' attribute to have a type");
if (attrType != type)
return emitOpError("requires the type of the 'value' attribute to match "
"that of the operation result");
return emitOpError(
"requires a result type that aligns with the 'value' attribute");
}
Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.empty() && "constant has no operands");
return getValue();
}
void ConstantFloatOp::build(Builder *builder, OperationState *result,
const APFloat &value, FloatType type) {
ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value));
}
bool ConstantFloatOp::isClassFor(const Instruction *op) {
return ConstantOp::isClassFor(op) &&
op->getResult(0)->getType().isa<FloatType>();
}
/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Instruction *op) {
return ConstantOp::isClassFor(op) &&
op->getResult(0)->getType().isa<IntegerType>();
}
void ConstantIntOp::build(Builder *builder, OperationState *result,
int64_t value, unsigned width) {
Type type = builder->getIntegerType(width);
ConstantOp::build(builder, result, type,
builder->getIntegerAttr(type, value));
}
/// Build a constant int op producing an integer with the specified type,
/// which must be an integer type.
void ConstantIntOp::build(Builder *builder, OperationState *result,
int64_t value, Type type) {
assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
ConstantOp::build(builder, result, type,
builder->getIntegerAttr(type, value));
}
/// ConstantIndexOp only matches values whose result type is Index.
bool ConstantIndexOp::isClassFor(const Instruction *op) {
return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
}
void ConstantIndexOp::build(Builder *builder, OperationState *result,
int64_t value) {
Type type = builder->getIndexType();
ConstantOp::build(builder, result, type,
builder->getIntegerAttr(type, value));
}
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
@ -1380,6 +1749,55 @@ Attribute RemIUOp::constantFold(ArrayRef<Attribute> operands,
return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhs.getValue()));
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
void ReturnOp::build(Builder *builder, OperationState *result,
ArrayRef<Value *> results) {
result->addOperands(results);
}
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types;
llvm::SMLoc loc;
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
parser->resolveOperands(opInfo, types, loc, result->operands);
}
void ReturnOp::print(OpAsmPrinter *p) const {
*p << "return";
if (getNumOperands() > 0) {
*p << ' ';
p->printOperands(operand_begin(), operand_end());
*p << " : ";
interleave(
operand_begin(), operand_end(),
[&](const Value *e) { p->printType(e->getType()); },
[&]() { *p << ", "; });
}
}
bool ReturnOp::verify() const {
auto *function = getInstruction()->getFunction();
// The operand number and types must match the function signature.
const auto &results = function->getType().getResults();
if (getNumOperands() != results.size())
return emitOpError("has " + Twine(getNumOperands()) +
" operands, but enclosing function returns " +
Twine(results.size()));
for (unsigned i = 0, e = results.size(); i != e; ++i)
if (getOperand(i)->getType() != results[i])
return emitError("type of return operand " + Twine(i) +
" doesn't match function result type");
return false;
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

View File

@ -19,6 +19,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"

View File

@ -25,7 +25,6 @@
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/Passes.h"

View File

@ -27,7 +27,6 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/LoopUtils.h"

View File

@ -26,7 +26,6 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"

View File

@ -49,7 +49,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"

View File

@ -23,7 +23,6 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"

View File

@ -30,7 +30,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"

View File

@ -32,7 +32,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"

View File

@ -20,9 +20,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;

View File

@ -29,7 +29,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instruction.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/ADT/DenseMap.h"

View File

@ -25,7 +25,6 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"

View File

@ -26,7 +26,6 @@
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"

View File

@ -21,7 +21,6 @@
#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"