forked from OSchip/llvm-project
Add Ch.5 of the toy tutorial.
This chapter adds a partial lowering of toy operations, all but PrintOp, to a combination of the Affine and Std dialects. This chapter focuses on introducing the conversion framework, the benefits of partial lowering, and how easily dialects may co-exist in the IR. PiperOrigin-RevId: 275150649
This commit is contained in:
parent
7045471913
commit
1ba9bb0507
|
@ -55,7 +55,8 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
|
|||
ConstantOp::build(builder, state, dataType, dataAttribute);
|
||||
}
|
||||
|
||||
/// Verifier for constant operation.
|
||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
||||
/// in the op definition.
|
||||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
|
@ -63,6 +64,8 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
|||
if (!resultType)
|
||||
return success();
|
||||
|
||||
// Check that the rank of the attribute type matches the rank of the constant
|
||||
// result type.
|
||||
auto attrType = op.value().getType().cast<mlir::TensorType>();
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return op.emitOpError(
|
||||
|
@ -70,7 +73,9 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
|||
"attribute: ")
|
||||
<< attrType.getRank() << " != " << resultType.getRank();
|
||||
}
|
||||
for (int dim = 0; dim < attrType.getRank(); ++dim) {
|
||||
|
||||
// Check that each of the dimensions match between the two types.
|
||||
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||
return op.emitOpError(
|
||||
"return type shape mismatches its attribute at dimension ")
|
||||
|
|
|
@ -118,6 +118,8 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
|||
if (!resultType)
|
||||
return success();
|
||||
|
||||
// Check that the rank of the attribute type matches the rank of the constant
|
||||
// result type.
|
||||
auto attrType = op.value().getType().cast<mlir::TensorType>();
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return op.emitOpError(
|
||||
|
@ -125,7 +127,9 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
|||
"attribute: ")
|
||||
<< attrType.getRank() << " != " << resultType.getRank();
|
||||
}
|
||||
for (int dim = 0; dim < attrType.getRank(); ++dim) {
|
||||
|
||||
// Check that each of the dimensions match between the two types.
|
||||
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||
return op.emitOpError(
|
||||
"return type shape mismatches its attribute at dimension ")
|
||||
|
|
|
@ -79,7 +79,7 @@ public:
|
|||
// the structural properties of the IR and invoke any specific verifiers we
|
||||
// have on the Toy operations.
|
||||
if (failed(mlir::verify(theModule))) {
|
||||
theModule.emitError("Module verification error");
|
||||
theModule.emitError("module verification error");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -229,7 +229,7 @@ private:
|
|||
if (auto *variable = symbolTable.lookup(expr.getName()))
|
||||
return variable;
|
||||
|
||||
emitError(loc(expr.loc()), "Error: unknown variable '")
|
||||
emitError(loc(expr.loc()), "error: unknown variable '")
|
||||
<< expr.getName() << "'";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -289,7 +289,8 @@ private:
|
|||
auto dataAttribute =
|
||||
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
|
||||
|
||||
// Build the MLIR op `toy.constant`.
|
||||
// Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
|
||||
// method.
|
||||
return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
|
||||
}
|
||||
|
||||
|
@ -389,7 +390,7 @@ private:
|
|||
auto init = vardecl.getInitVal();
|
||||
if (!init) {
|
||||
emitError(loc(vardecl.loc()),
|
||||
"Missing initializer in variable declaration");
|
||||
"missing initializer in variable declaration");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,40 +1,42 @@
|
|||
add_subdirectory(include)
|
||||
|
||||
set(LLVM_LINK_COMPONENTS
|
||||
Support
|
||||
)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
|
||||
mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
|
||||
add_public_tablegen_target(ToyCh5CombineIncGen)
|
||||
|
||||
add_toy_chapter(toyc-ch5
|
||||
toyc.cpp
|
||||
parser/AST.cpp
|
||||
mlir/EarlyLowering.cpp
|
||||
mlir/LateLowering.cpp
|
||||
mlir/MLIRGen.cpp
|
||||
mlir/Dialect.cpp
|
||||
mlir/DeadFunctionEliminationPass.cpp
|
||||
mlir/LowerToAffineLoops.cpp
|
||||
mlir/ShapeInferencePass.cpp
|
||||
mlir/ToyDialect.cpp
|
||||
mlir/ToyCombine.cpp
|
||||
)
|
||||
|
||||
add_dependencies(toyc-ch5 ToyCh5ShapeInferenceInterfaceIncGen)
|
||||
add_dependencies(toyc-ch5 ToyCh5OpsIncGen)
|
||||
add_dependencies(toyc-ch5 ToyCh5CombineIncGen)
|
||||
add_dependencies(toyc-ch5 MLIRCallOpInterfacesIncGen)
|
||||
include_directories(include/)
|
||||
include_directories(../../Linalg/Linalg1/include/)
|
||||
include_directories(../../Linalg/Linalg2/include/)
|
||||
include_directories(../../Linalg/Linalg3/include/)
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
|
||||
target_link_libraries(toyc-ch5
|
||||
PRIVATE
|
||||
Linalg3DialectConstruction
|
||||
Linalg3
|
||||
Linalg2
|
||||
Linalg1
|
||||
MLIRAffineOps
|
||||
MLIRAnalysis
|
||||
MLIREDSC
|
||||
MLIRExecutionEngine
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
MLIRTargetLLVMIR
|
||||
MLIRTransforms
|
||||
MLIRSupport
|
||||
)
|
||||
MLIRStandardOps
|
||||
MLIRTransforms)
|
||||
|
||||
whole_archive_link(toyc-ch5
|
||||
MLIRAffineOps
|
||||
MLIRStandardOps
|
||||
)
|
||||
|
||||
)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(toy)
|
|
@ -33,10 +33,9 @@
|
|||
|
||||
namespace toy {
|
||||
|
||||
/// A variable
|
||||
/// A variable type with shape information.
|
||||
struct VarType {
|
||||
enum { TY_FLOAT, TY_INT } elt_ty;
|
||||
std::vector<int> shape;
|
||||
std::vector<int64_t> shape;
|
||||
};
|
||||
|
||||
/// Base class for all expression nodes.
|
||||
|
@ -50,9 +49,7 @@ public:
|
|||
Expr_Var,
|
||||
Expr_BinOp,
|
||||
Expr_Call,
|
||||
Expr_Print, // builtin
|
||||
Expr_If,
|
||||
Expr_For,
|
||||
Expr_Print,
|
||||
};
|
||||
|
||||
ExprAST(ExprASTKind kind, Location location)
|
||||
|
@ -85,7 +82,7 @@ public:
|
|||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
|
||||
};
|
||||
|
||||
///
|
||||
/// Expression class for a literal value.
|
||||
class LiteralExprAST : public ExprAST {
|
||||
std::vector<std::unique_ptr<ExprAST>> values;
|
||||
std::vector<int64_t> dims;
|
||||
|
@ -116,7 +113,7 @@ public:
|
|||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
|
||||
};
|
||||
|
||||
///
|
||||
/// Expression class for defining a variable.
|
||||
class VarDeclExprAST : public ExprAST {
|
||||
std::string name;
|
||||
VarType type;
|
||||
|
@ -136,7 +133,7 @@ public:
|
|||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
|
||||
};
|
||||
|
||||
///
|
||||
/// Expression class for a return operator.
|
||||
class ReturnExprAST : public ExprAST {
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Ops.td)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
|
||||
add_public_tablegen_target(ToyCh5OpsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
|
||||
mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(ToyCh5ShapeInferenceInterfaceIncGen)
|
|
@ -16,7 +16,7 @@
|
|||
// =============================================================================
|
||||
//
|
||||
// This file implements the IR Dialect for the Toy language.
|
||||
// See g3doc/Tutorials/Toy/Ch-3.md for more information.
|
||||
// See g3doc/Tutorials/Toy/Ch-2.md for more information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
@ -25,369 +25,31 @@
|
|||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "toy/ShapeInferenceInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
class Builder;
|
||||
}
|
||||
|
||||
namespace toy {
|
||||
|
||||
/// This is the definition of the Toy dialect. A dialect inherits from
|
||||
/// mlir::Dialect and register custom operations and types (in its constructor).
|
||||
/// It can also overriding general behavior of dialects exposed as virtual
|
||||
/// method, for example regarding verification and parsing/printing.
|
||||
/// mlir::Dialect and registers custom attributes, operations, and types (in its
|
||||
/// constructor). It can also override some general behavior exposed via virtual
|
||||
/// methods.
|
||||
class ToyDialect : public mlir::Dialect {
|
||||
public:
|
||||
explicit ToyDialect(mlir::MLIRContext *ctx);
|
||||
|
||||
/// Parse a type registered to this dialect. Overriding this method is
|
||||
/// required for dialects that have custom types.
|
||||
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||
mlir::Type parseType(llvm::StringRef tyData,
|
||||
mlir::Location loc) const override;
|
||||
|
||||
/// Print a type registered to this dialect. Overriding this method is
|
||||
/// only required for dialects that have custom types.
|
||||
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||
void printType(mlir::Type type, llvm::raw_ostream &os) const override;
|
||||
/// Provide a utility accessor to the dialect namespace. This is used by
|
||||
/// several utilities for casting between dialects.
|
||||
static llvm::StringRef getDialectNamespace() { return "toy"; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////// Custom Types for the Dialect ///////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
struct ToyArrayTypeStorage;
|
||||
}
|
||||
|
||||
/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
|
||||
enum ToyTypeKind {
|
||||
// The enum starts at the range reserved for this dialect.
|
||||
TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
|
||||
TOY_ARRAY,
|
||||
};
|
||||
|
||||
/// Type for Toy arrays.
|
||||
/// In MLIR Types are reference to immutable and uniqued objects owned by the
|
||||
/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued
|
||||
/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
|
||||
/// provides the public facade API to interact with the type.
|
||||
class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
|
||||
detail::ToyArrayTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Returns the dimensions for this array, or and empty range for a generic
|
||||
/// array.
|
||||
llvm::ArrayRef<int64_t> getShape();
|
||||
|
||||
/// Predicate to test if this array is generic (shape haven't been inferred
|
||||
/// yet).
|
||||
bool isGeneric() { return getShape().empty(); }
|
||||
|
||||
/// Return the rank of this array (0 if it is generic).
|
||||
int getRank() { return getShape().size(); }
|
||||
|
||||
/// Return the type of individual elements in the array.
|
||||
mlir::Type getElementType();
|
||||
|
||||
/// Get a MemRef equivalent to this array type.
|
||||
mlir::MemRefType toMemref();
|
||||
|
||||
/// Get the unique instance of this Type from the context.
|
||||
/// A ToyArrayType is only defined by the shape of the array.
|
||||
static ToyArrayType get(mlir::MLIRContext *context,
|
||||
llvm::ArrayRef<int64_t> shape = {});
|
||||
|
||||
/// Support method to enable LLVM-style RTTI type casting.
|
||||
static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////// Custom Operations for the Dialect /////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Constant operation turns a literal into an SSA value. The data is attached
|
||||
/// to the operation as an attribute. For example:
|
||||
///
|
||||
/// %0 = "toy.constant"()
|
||||
/// {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
|
||||
/// : () -> !toy<"array<2, 3>">
|
||||
///
|
||||
/// An operation inherits from `class Op` and specifies optional traits. Here we
|
||||
/// indicate that `toy.constant` does not have any operands and returns a single
|
||||
/// result. The traits provide some utilities methods for the operation, for
|
||||
/// instance we will be able to use `getResult()`, but `getOperand()` won't be
|
||||
/// available.
|
||||
class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
/// This is the name used by MLIR to match an operation to this class during
|
||||
/// parsing.
|
||||
static llvm::StringRef getOperationName() { return "toy.constant"; }
|
||||
|
||||
/// The operation can have extra verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populates the `state` that MLIR uses to create operations.
|
||||
/// The `toy.constant` operation does not have arguments but attaches a
|
||||
/// constant array as an attribute and returns it as an SSA value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
mlir::DenseElementsAttr value);
|
||||
|
||||
/// Similar to the one above, but takes a single float and returns a
|
||||
/// !toy<"array<1>">.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::FloatAttr value);
|
||||
|
||||
mlir::DenseElementsAttr getValue() {
|
||||
return getAttr("value").cast<mlir::DenseElementsAttr>();
|
||||
}
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Generic calls represent calls to a user defined function that needs to
|
||||
/// be specialized for the shape of its arguments. The callee name is attached
|
||||
/// as a literal string as an attribute. The arguments list must match the
|
||||
/// arguments expected by the callee. For example:
|
||||
///
|
||||
/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
|
||||
/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
|
||||
///
|
||||
/// This is only valid if a function named "my_func" exists and takes two
|
||||
/// arguments.
|
||||
class GenericCallOp
|
||||
: public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::OneResult> {
|
||||
public:
|
||||
/// MLIR will use this to register the operation with the parser/printer.
|
||||
static llvm::StringRef getOperationName() { return "toy.generic_call"; }
|
||||
|
||||
/// Operations can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to the builder to allow:
|
||||
/// mlir::Builder::create<GenericCallOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.generic_call` operation accepts a callee name and a list of
|
||||
/// arguments for the call.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
llvm::StringRef callee,
|
||||
llvm::ArrayRef<mlir::Value *> arguments);
|
||||
|
||||
/// Return the name of the callee.
|
||||
llvm::StringRef getCalleeName();
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Return operations terminate blocks (and functions as well). They take a
|
||||
/// single argument and the type must match the function return type.
|
||||
class ReturnOp
|
||||
: public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.return"; }
|
||||
|
||||
/// Operations can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.return` operation accepts an optional single array as an argument
|
||||
/// and does not have any returned value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value = nullptr);
|
||||
|
||||
/// Return true if there is a returned value.
|
||||
bool hasOperand() { return 0 != getNumOperands(); }
|
||||
|
||||
/// Helper to return the optional operand. Caller must check if the operand
|
||||
/// is present before calling this.
|
||||
mlir::Value *getOperand() { return getOperation()->getOperand(0); }
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// The print builtin takes a single array argument and does not return any.
|
||||
class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
|
||||
mlir::OpTrait::ZeroResult> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.print"; }
|
||||
|
||||
/// Operations can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.print` operation accepts a single array as argument and does
|
||||
/// not have any returned value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value);
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.transpose"; }
|
||||
|
||||
/// Operation can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<TransposeOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.transpose` operation accepts a single array as argument and
|
||||
/// returns the transposed array as its only result.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value);
|
||||
|
||||
// Register our patterns for rewrite by the Canonicalization framework.
|
||||
static void
|
||||
getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
|
||||
mlir::MLIRContext *context);
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Reshape operation is transforming its input array into a new array with the
|
||||
/// same number of elements but different shapes. For example:
|
||||
///
|
||||
/// %0 = "toy.transpose"(%arg1) : (!toy<"array<10>">) -> !toy<"array<5, 2>">
|
||||
///
|
||||
class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.reshape"; }
|
||||
|
||||
/// Operation can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<ReshapeOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.reshape` operation accepts a single array as argument and
|
||||
/// returns the array with the specified reshapedType as its only result.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value, ToyArrayType reshapedType);
|
||||
|
||||
// Register our patterns for rewrite by the Canonicalization framework.
|
||||
static void
|
||||
getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
|
||||
mlir::MLIRContext *context);
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Binary operation implementing a multiplication. For two-dimensional array
|
||||
/// a matrix multiplication is implemented, while for one dimensional array a
|
||||
/// dot product is performed.
|
||||
class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.mul"; }
|
||||
|
||||
/// Operation can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.mul` operation accepts two operands as argument and returns
|
||||
/// a single value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *lhs, mlir::Value *rhs);
|
||||
|
||||
/// Convenience accessor for LHS of the expression.
|
||||
mlir::Value *getLHS() { return getOperand(0); }
|
||||
|
||||
/// Convenience accessor for RHS of the expression.
|
||||
mlir::Value *getRHS() { return getOperand(1); }
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Element wise addition of two arrays. The shape must match.
|
||||
class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.add"; }
|
||||
|
||||
/// Operation can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.mul` operation accepts two operands as argument and returns
|
||||
/// a single value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *lhs, mlir::Value *rhs);
|
||||
|
||||
/// Convenience accessor for LHS of the expression.
|
||||
mlir::Value *getLHS() { return getOperand(0); }
|
||||
|
||||
/// Convenience accessor for RHS of the expression.
|
||||
mlir::Value *getRHS() { return getOperand(1); }
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// AllocOp is a temporary operation for buffer allocation, created as part of
|
||||
/// partial lowering.
|
||||
class AllocOp : public mlir::Op<AllocOp, mlir::OpTrait::ZeroOperands,
|
||||
mlir::OpTrait::OneResult> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.alloc"; }
|
||||
|
||||
/// Interface to mlir::Builder::create<AllocOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// `toy.alloc` does not have any argument and returns a toy array.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Type retType);
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// FIXME: should be in std?
|
||||
class TypeCastOp : public mlir::Op<TypeCastOp, mlir::OpTrait::OneOperand,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.cast"; }
|
||||
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value, mlir::Type destTy);
|
||||
|
||||
// Register our patterns for rewrite by the Canonicalization framework.
|
||||
static void
|
||||
getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
|
||||
mlir::MLIRContext *context);
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
/// Include the auto-generated header file containing the declarations of the
|
||||
/// toy operations.
|
||||
#define GET_OP_CLASSES
|
||||
#include "toy/Ops.h.inc"
|
||||
|
||||
} // end namespace toy
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace toy {
|
|||
|
||||
/// Structure definition a location in a file.
|
||||
struct Location {
|
||||
std::shared_ptr<std::string> file; ///< filename
|
||||
std::shared_ptr<std::string> file; ///< filename.
|
||||
int line; ///< line number.
|
||||
int col; ///< column number.
|
||||
};
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
//===- Lowering.h - Lexer for the Toy language ----------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file exposes the interface to the lowering for Toy. It is divided in
|
||||
// two parts: an *early lowering* that emits operations in the `Linalg`
|
||||
// dialects for a subset of the Toy IR, and a *late lowering* that materializes
|
||||
// buffers and converts all operations and type to the LLVM dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_EXAMPLES_TOY_LOWERING_H_
|
||||
#define MLIR_EXAMPLES_TOY_LOWERING_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class Pass;
|
||||
class DialectConversion;
|
||||
} // namespace mlir
|
||||
|
||||
namespace toy {
|
||||
/// Create a pass for lowering operations in the `Linalg` dialects, for a subset
|
||||
/// of the Toy IR (matmul).
|
||||
std::unique_ptr<mlir::Pass> createEarlyLoweringPass();
|
||||
|
||||
/// Create a pass for the late lowering toward LLVM dialect.
|
||||
std::unique_ptr<mlir::Pass> createLateLoweringPass();
|
||||
|
||||
} // namespace toy
|
||||
|
||||
#endif // MLIR_EXAMPLES_TOY_LOWERING_H_
|
|
@ -0,0 +1,272 @@
|
|||
//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines the operations of the Toy dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef TOY_OPS
|
||||
#else
|
||||
#define TOY_OPS
|
||||
|
||||
#ifdef MLIR_CALLINTERFACES
|
||||
#else
|
||||
include "mlir/Analysis/CallInterfaces.td"
|
||||
#endif // MLIR_CALLINTERFACES
|
||||
|
||||
#ifdef SHAPE_INFERENCE_INTERFACE
|
||||
#else
|
||||
include "toy/ShapeInferenceInterface.td"
|
||||
#endif // SHAPE_INFERENCE_INTERFACE
|
||||
|
||||
// Provide a definition of the 'toy' dialect in the ODS framework so that we
|
||||
// can define our operations.
|
||||
def Toy_Dialect : Dialect {
|
||||
let name = "toy";
|
||||
let cppNamespace = "toy";
|
||||
}
|
||||
|
||||
// Base class for toy dialect operations. This operation inherits from the base
|
||||
// `Op` class in OpBase.td, and provides:
|
||||
// * The parent dialect of the operation.
|
||||
// * The mnemonic for the operation, or the name without the dialect prefix.
|
||||
// * A list of traits for the operation.
|
||||
class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Toy_Dialect, mnemonic, traits>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Toy Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// We define a toy operation by inherting from our base 'Toy_Op' class above.
|
||||
// Here we provide the mnemonic and a list of traits for the operation. The
|
||||
// constant operation is marked as 'NoSideEffect' as it is a pure operation
|
||||
// and may be removed if dead.
|
||||
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||
// Provide a summary and description for this operation. This can be used to
|
||||
// auto-generate documenatation of the operations within our dialect.
|
||||
let summary = "constant";
|
||||
let description = [{
|
||||
Constant operation turns a literal into an SSA value. The data is attached
|
||||
to the operation as an attribute. For example:
|
||||
|
||||
```mlir
|
||||
%0 = "toy.constant"()
|
||||
{ value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
|
||||
: () -> tensor<2x3xf64>
|
||||
```
|
||||
}];
|
||||
|
||||
// The constant operation takes an attribute as the only input.
|
||||
let arguments = (ins F64ElementsAttr:$value);
|
||||
|
||||
// The constant operation returns a single value of TensorType.
|
||||
let results = (outs F64Tensor);
|
||||
|
||||
// Add custom build methods for the constant operation. These method populates
|
||||
// the `state` that MLIR uses to create operations, i.e. these are used when
|
||||
// using `builder.create<ConstantOp>(...)`.
|
||||
let builders = [
|
||||
// Build a constant with a given constant tensor value.
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"DenseElementsAttr value", [{
|
||||
build(builder, result, value.getType(), value);
|
||||
}]>,
|
||||
|
||||
// Build a constant with a given constant floating-point value.
|
||||
OpBuilder<"Builder *builder, OperationState &result, double value", [{
|
||||
buildConstantOp(builder, result, value);
|
||||
}]>
|
||||
];
|
||||
|
||||
// Invoke a static verify method to verify this constant operation.
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
// Element-wise addition of two f64 tensors. Marked 'NoSideEffect' (pure, may
// be removed if dead) and implements the shape inference interface so the
// result type can be refined from the operand types.
def AddOp : Toy_Op<"add",
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let summary = "element-wise addition operation";
  let description = [{
    The "add" operation performs element-wise addition between two tensors.
    The shapes of the tensor operands are expected to match.
  }];

  // Two f64 tensor operands and a single f64 tensor result.
  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs F64Tensor);

  // Allow building an AddOp from the two input operands.
  let builders = [
    OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
      buildAddOp(b, result, lhs, rhs);
    }]
  >];
}
|
||||
|
||||
// Type-bridging cast between equivalent tensor types. Used during inlining and
// shape inference to reconcile ranked and unranked tensor types; the
// 'SameOperandsAndResultShape' trait enforces shape compatibility.
def CastOp : Toy_Op<"cast",
    [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
     SameOperandsAndResultShape]> {
  let summary = "shape cast operation";
  let description = [{
    The "cast" operation converts a tensor from one type to an equivalent type
    without changing any data elements. The source and destination types
    must both be tensor types with the same element type. If both are ranked
    then the rank should be the same and static dimensions should match. The
    operation is invalid if converting to a mismatching constant dimension.
  }];

  // Single f64 tensor operand and result; only the type differs.
  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor:$output);

  // Set the folder bit so that we can fold redundant cast operations.
  let hasFolder = 1;
}
|
||||
|
||||
// Call to a user-defined toy function; implements the CallOpInterface so it
// participates in MLIR's generic call graph analyses (e.g. inlining).
def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
  let summary = "generic call operation";
  let description = [{
    Generic calls represent calls to a user defined function that needs to
    be specialized for the shape of its arguments. The callee name is attached
    as a symbol reference via an attribute. The arguments list must match the
    arguments expected by the callee. For example:

    ```mlir
     %4 = "toy.generic_call"(%1, %3) {callee = @my_func}
           : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
    ```

    This is only valid if a function named "my_func" exists and takes two
    arguments.
  }];

  // The generic call operation takes a symbol reference attribute as the
  // callee, and inputs for the call.
  let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);

  // The generic call operation returns a single value of TensorType.
  let results = (outs F64Tensor);

  // Add custom build methods for the generic call operation.
  let builders = [
    // Build a generic call given the callee name and the argument values.
    OpBuilder<"Builder *builder, OperationState &result, "
              "StringRef callee, ArrayRef<Value *> arguments", [{
      buildGenericCallOp(builder, result, callee, arguments);
    }]>
  ];
}
|
||||
|
||||
// Multiplication of two f64 tensors; pure and shape-inferable.
// NOTE(review): MulOp::inferShapes in Dialect.cpp computes a matmul-style
// result shape (lhs rows x rhs cols), which does not match the element-wise
// description below — confirm which semantic is intended.
def MulOp : Toy_Op<"mul",
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let summary = "element-wise multiplication operation";
  let description = [{
    The "mul" operation performs element-wise multiplication between two
    tensors. The shapes of the tensor operands are expected to match.
  }];

  // Two f64 tensor operands and a single f64 tensor result.
  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs F64Tensor);

  // Allow building a MulOp from the two input operands.
  let builders = [
    OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
      buildMulOp(b, result, lhs, rhs);
    }]
  >];
}
|
||||
|
||||
// Side-effecting print of a tensor (no 'NoSideEffect' trait, so it is never
// removed as dead code). It produces no results.
def PrintOp : Toy_Op<"print"> {
  let summary = "print operation";
  let description = [{
    The "print" builtin operation prints a given input tensor, and produces
    no results.
  }];

  // The print operation takes an input tensor to print.
  // We also allow a F64MemRef to enable interop during partial lowering.
  let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
}
|
||||
|
||||
// Pure reshape of a tensor to a new (statically known) shape with the same
// number of elements.
def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
  let summary = "tensor reshape operation";
  let description = [{
    Reshape operation is transforming its input tensor into a new tensor with
    the same number of elements but different shapes. For example:

    ```mlir
       %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
    ```
  }];

  let arguments = (ins F64Tensor:$input);
  // Enable registration of canonicalization patterns for this operation.
  let hasCanonicalizer = 1;

  // We expect that the reshape operation returns a statically shaped tensor.
  let results = (outs StaticShapeTensorOf<[F64]>);
}
|
||||
|
||||
// Function terminator: must appear as the last operation of a block inside a
// FuncOp (enforced by the 'Terminator' and 'HasParent' traits).
def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
  let summary = "return operation";
  let description = [{
    The "return" operation represents a return operation within a function.
    The operation takes an optional tensor operand and produces no results.
    The operand type must match the signature of the function that contains
    the operation. For example:

    ```mlir
      func @foo() -> tensor<2xf64> {
        ...
        toy.return %0 : tensor<2xf64>
      }
    ```
  }];

  // The return operation takes an optional input operand to return. This
  // value must match the return type of the enclosing function.
  let arguments = (ins Variadic<F64Tensor>:$input);

  // Allow building a ReturnOp with no return operand.
  let builders = [OpBuilder<
    "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
  >];

  // Provide extra utility definitions on the c++ operation class definition.
  let extraClassDeclaration = [{
    bool hasOperand() { return getNumOperands() != 0; }
  }];

  // Invoke a static verify method to verify this return operation.
  let verifier = [{ return ::verify(*this); }];
}
|
||||
|
||||
// Pure transpose of an f64 tensor; implements shape inference so the result
// type can be computed from the operand type.
def TransposeOp : Toy_Op<"transpose",
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let summary = "transpose operation";

  // Single f64 tensor operand and result.
  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor);
  // Enable registration of canonicalization patterns for this operation.
  let hasCanonicalizer = 1;

  // Allow building a TransposeOp from the single input operand.
  let builders = [
    OpBuilder<"Builder *b, OperationState &result, Value *input", [{
      buildTransposeOp(b, result, input);
    }]
  >];
}
|
||||
|
||||
#endif // TOY_OPS
|
|
@ -26,10 +26,16 @@
|
|||
|
||||
namespace mlir {
|
||||
class Pass;
|
||||
} // namespace mlir
|
||||
|
||||
namespace toy {
|
||||
std::unique_ptr<mlir::Pass> createShapeInferencePass();
|
||||
} // namespace toy
|
||||
std::unique_ptr<Pass> createDeadFunctionEliminationPass();
|
||||
std::unique_ptr<Pass> createShapeInferencePass();
|
||||
|
||||
/// Create a pass for lowering to operations in the `Affine` and `Std` dialects,
|
||||
/// for a subset of the Toy IR (e.g. matmul).
|
||||
std::unique_ptr<mlir::Pass> createLowerToAffinePass();
|
||||
|
||||
} // end namespace toy
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_PASSES_H
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the declarations of the shape inference interfaces defined
|
||||
// in ShapeInferenceInterface.td.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
|
||||
#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace toy {
|
||||
|
||||
/// Include the auto-generated declarations.
|
||||
#include "toy/ShapeInferenceOpInterfaces.h.inc"
|
||||
|
||||
} // end namespace toy
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
|
|
@ -0,0 +1,38 @@
|
|||
//===- ShapeInferenceInterface.td - Operation Interface for Shape Inference ----------*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines the operations of the Shape Inference Op Interface.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef SHAPE_INFERENCE_INTERFACE
|
||||
#else
|
||||
#define SHAPE_INFERENCE_INTERFACE
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
// Op interface for shape inference: each implementing operation computes and
// sets the type of its own results from its operand types. The generated C++
// interface class is named `ShapeInference`.
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
|
||||
|
||||
#endif // SHAPE_INFERENCE_INTERFACE
|
|
@ -0,0 +1,68 @@
|
|||
//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a Module level pass performing dead function
|
||||
// elimination. This is required as a post-processing step after function
|
||||
// inlining.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "toy/Passes.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <algorithm>
|
||||
|
||||
namespace {
/// This is a simple function DCE pass that deletes all non-main functions after
/// inlining.
/// TODO(riverriddle) This is only necessary because MLIR currently does not
/// have generic DCE support for functions.
class DeadFunctionEliminationPass
    : public mlir::ModulePass<DeadFunctionEliminationPass> {
public:
  void runOnModule() override {
    mlir::ModuleOp module = getModule();
    mlir::SymbolTable moduleSymTable(module);

    // Eliminate non-main functions.
    // NOTE(review): if no "main" exists, the lookup yields a null FuncOp and
    // every function in the module is erased — confirm that is intended.
    auto mainFn = moduleSymTable.lookup<mlir::FuncOp>("main");
    // make_early_inc_range lets us erase the current function while iterating
    // without invalidating the iterator.
    for (mlir::FuncOp func :
         llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
      if (func != mainFn)
        func.erase();
    }
  }
};
} // end anonymous namespace
|
||||
|
||||
/// Create a pass that eliminates inlined functions in toy.
/// Ownership of the returned pass transfers to the caller.
std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
  return std::make_unique<DeadFunctionEliminationPass>();
}
|
|
@ -0,0 +1,256 @@
|
|||
//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the dialect for the Toy IR: custom type parsing and
|
||||
// operation verification.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::toy;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyInlinerInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class defines the interface for handling inlining with Toy
/// operations. It is registered on the dialect (see ToyDialect's constructor)
/// and queried by MLIR's generic inliner.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// All operations within toy can be inlined.
  bool isLegalToInline(Operation *, Region *,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator(toy.return) by replacing it with a new
  /// operation as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value *> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
  }

  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value *input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    // A toy.cast bridges the mismatched tensor types.
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  // Register all of the operations generated from the TableGen definitions.
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  // Register the inliner interface so toy ops participate in inlining.
  addInterfaces<ToyInlinerInterface>();
}
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Toy Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Build a constant operation.
|
||||
/// The builder is passed as an argument, so is the state that this method is
|
||||
/// expected to fill in order to build the operation.
|
||||
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||
double value) {
|
||||
auto dataType = builder->getTensorType({}, builder->getF64Type());
|
||||
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
||||
ConstantOp::build(builder, state, dataType, dataAttribute);
|
||||
}
|
||||
|
||||
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
// The cast result simply adopts the operand's (possibly refined) type.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
||||
|
||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
||||
/// in the op definition.
|
||||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
// Check that the rank of the attribute type matches the rank of the constant
|
||||
// result type.
|
||||
auto attrType = op.value().getType().cast<mlir::TensorType>();
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return op.emitOpError(
|
||||
"return type must match the one of the attached value "
|
||||
"attribute: ")
|
||||
<< attrType.getRank() << " != " << resultType.getRank();
|
||||
}
|
||||
|
||||
// Check that each of the dimensions match between the two types.
|
||||
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||
return op.emitOpError(
|
||||
"return type shape mismatches its attribute at dimension ")
|
||||
<< dim << ": " << attrType.getShape()[dim]
|
||||
<< " != " << resultType.getShape()[dim];
|
||||
}
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Populate `state` to build an "add" operation from the two given operands.
/// The result type starts out as an unranked f64 tensor; shape inference
/// refines it later.
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::Value *lhs, mlir::Value *rhs) {
  state.addOperands({lhs, rhs});
  state.addTypes(builder->getTensorType(builder->getF64Type()));
}
|
||||
|
||||
/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
// Element-wise add: the result takes the type of the first operand.
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
|
||||
/// Populate `state` with the callee attribute, the call arguments and the
/// (initially unranked) result type of a "toy.generic_call" operation.
static void buildGenericCallOp(mlir::Builder *builder,
                               mlir::OperationState &state, StringRef callee,
                               ArrayRef<mlir::Value *> arguments) {
  state.addAttribute("callee", builder->getSymbolRefAttr(callee));
  state.addOperands(arguments);
  // Generic call always returns an unranked Tensor initially.
  state.addTypes(builder->getTensorType(builder->getF64Type()));
}
|
||||
|
||||
/// Return the callee of the generic call operation, this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  // The callee is stored as a symbol reference attribute on the operation.
  return getAttrOfType<SymbolRefAttr>("callee");
}
|
||||
|
||||
/// Get the argument operands to the called function, this is required by the
/// call interface.
// All of the op's operands are forwarded to the callee.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
|
||||
|
||||
/// Populate `state` to build a "mul" operation from the two given operands.
/// The result type starts out as an unranked f64 tensor; shape inference
/// refines it later.
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::Value *lhs, mlir::Value *rhs) {
  state.addOperands({lhs, rhs});
  state.addTypes(builder->getTensorType(builder->getF64Type()));
}
|
||||
|
||||
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
// NOTE(review): this computes a matrix-multiplication result shape
// (lhs rows x rhs cols, or <1> for the rank-1 dot product), while the op's
// TableGen summary describes "mul" as element-wise — confirm the intended
// semantics.
void MulOp::inferShapes() {
  auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
  auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
  auto lhsRank = lhs.getShape().size();
  auto rhsRank = rhs.getShape().size();
  // Mismatched ranks: bail out and leave the result type unrefined.
  if (lhsRank != rhsRank)
    return;

  SmallVector<int64_t, 2> dims;
  if (lhsRank == 1) {
    // dot product, result shape is <1>
    dims.push_back(1);
  } else if (lhsRank == 2) {
    // Matmul-style shape: rows of lhs by columns of rhs.
    dims.push_back(lhs.getShape()[0]);
    dims.push_back(rhs.getShape()[1]);
  } else {
    // Ranks above 2 are not handled; leave the result type unchanged.
    return;
  }
  getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
}
|
||||
|
||||
/// Verifier for the return operation: checks the operand count and type
/// against the signature of the enclosing function.
static mlir::LogicalResult verify(ReturnOp op) {
  // We know that the parent operation is a function, because of the 'HasParent'
  // trait attached to the operation definition.
  auto function = cast<FuncOp>(op.getParentOp());

  // ReturnOps can only have a single optional operand.
  if (op.getNumOperands() > 1)
    return op.emitOpError() << "expects at most 1 return operand";

  // The operand number and types must match the function signature.
  const auto &results = function.getType().getResults();
  if (op.getNumOperands() != results.size())
    return op.emitOpError()
           << "does not return the same number of values ("
           << op.getNumOperands() << ") as the enclosing function ("
           << results.size() << ")";

  // If the operation does not have an input, we are done.
  if (!op.hasOperand())
    return mlir::success();

  auto inputType = *op.operand_type_begin();
  auto resultType = results.front();

  // Check that the result type of the function matches the operand type.
  // Unranked tensors are compatible with anything; a later cast/shape
  // inference step reconciles them.
  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
      resultType.isa<mlir::UnrankedTensorType>())
    return mlir::success();

  return op.emitError() << "type of return operand ("
                        << *op.operand_type_begin()
                        << ") doesn't match function result type ("
                        << results.front() << ")";
}
|
||||
|
||||
/// Populate `state` to build a "transpose" operation from the given operand.
/// The result type starts out as an unranked f64 tensor; shape inference
/// refines it later.
static void buildTransposeOp(mlir::Builder *builder,
                             mlir::OperationState &state, mlir::Value *value) {
  state.addOperands(value);
  state.addTypes(builder->getTensorType(builder->getF64Type()));
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
SmallVector<int64_t, 2> dims;
|
||||
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||
dims.insert(dims.end(), arrayTy.getShape().begin(), arrayTy.getShape().end());
|
||||
if (dims.size() == 2)
|
||||
std::swap(dims[0], dims[1]);
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "toy/Ops.cpp.inc"
|
|
@ -1,148 +0,0 @@
|
|||
//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements early lowering of Toy IR to Linalg Dialect: we only
|
||||
// lower the computationally intensive part of the program (matmul...) to a
|
||||
// dialect specialized for optimizations.
|
||||
//
|
||||
// This is intended to showcase how multiple dialects can cohabit in the same
|
||||
// function. After this lowering, you would still have toy.print in the IR for
|
||||
// example.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "linalg1/Dialect.h"
|
||||
#include "linalg1/Intrinsics.h"
|
||||
#include "linalg1/ViewOp.h"
|
||||
#include "linalg3/TensorOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Utility function for type casting: this is making the type checker happy,
/// while delaying the actual work involved to convert the type. Most of the
/// time both side of the cast (producer and consumer) will be lowered to a
/// dialect like LLVM and end up with the same LLVM representation, at which
/// point this becomes a no-op and is eliminated.
Value *typeCast(ConversionPatternRewriter &builder, Value *val, Type destTy) {
  // Identity casts are elided entirely.
  if (val->getType() == destTy)
    return val;
  return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
      .getResult();
}
|
||||
|
||||
/// Create a type cast to turn a toy.array into a memref. The Toy Array will be
/// lowered to a memref during buffer allocation, at which point the type cast
/// becomes useless.
Value *memRefTypeCast(ConversionPatternRewriter &builder, Value *val) {
  // Already a memref: nothing to do.
  if (val->getType().isa<MemRefType>())
    return val;
  // Only toy arrays can be cast; any other type is returned unchanged.
  auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
  if (!toyArrayTy)
    return val;
  return typeCast(builder, val, toyArrayTy.toMemref());
}
|
||||
|
||||
/// Lower toy.mul to Linalg `matmul`.
///
/// This class inherits from `ConversionPattern` and overrides `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see `LateLowering`
/// class below).
class MulOpConversion : public ConversionPattern {
public:
  explicit MulOpConversion(MLIRContext *context)
      : ConversionPattern(toy::MulOp::getOperationName(), 1, context) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    using namespace edsc;
    using intrinsics::constant_index;
    using linalg::intrinsics::range;
    using linalg::intrinsics::view;
    toy::MulOp mul = cast<toy::MulOp>(op);
    auto loc = mul.getLoc();
    // Allocate the result storage and view it as a memref.
    Value *result = memRefTypeCast(
        rewriter, rewriter.create<toy::AllocOp>(loc, mul.getResult()->getType())
                      .getResult());
    Value *lhs = memRefTypeCast(rewriter, operands[0]);
    auto memrefLHSTy = lhs->getType().cast<MemRefType>();
    Value *rhs = memRefTypeCast(rewriter, operands[1]);
    auto memrefRHSTy = rhs->getType().cast<MemRefType>();
    mlir::edsc::ScopedContext scope(rewriter, loc);
    // Full ranges ([0, dim) with step 1) over lhs rows, the shared inner
    // dimension, and rhs columns.
    edsc::ValueHandle r0 =
        range(constant_index(0), constant_index(memrefLHSTy.getDimSize(0)),
              constant_index(1));
    edsc::ValueHandle r1 =
        range(constant_index(0), constant_index(memrefLHSTy.getDimSize(1)),
              constant_index(1));
    edsc::ValueHandle r2 =
        range(constant_index(0), constant_index(memrefRHSTy.getDimSize(1)),
              constant_index(1));
    auto lhsView = view(lhs, {r0, r1});
    auto rhsView = view(rhs, {r1, r2});
    auto resultView = view(result, {r0, r2});
    rewriter.create<linalg::MatmulOp>(loc, lhsView, rhsView, resultView);
    // Cast the result back to the original toy type for the op's users.
    rewriter.replaceOp(op, {typeCast(rewriter, result, mul.getType())});
    return matchSuccess();
  }
};
|
||||
|
||||
/// This is lowering to Linalg the parts that are computationally intensive
/// (like matmul for example...) while keeping the rest of the code in the Toy
/// dialect.
struct EarlyLoweringPass : public FunctionPass<EarlyLoweringPass> {
  void runOnFunction() override {
    ConversionTarget target(getContext());
    // Linalg and standard ops are legal after this pass, as are the toy
    // alloc/type-cast ops emitted by the conversion pattern itself.
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
    target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();

    OwningRewritePatternList patterns;
    patterns.insert<MulOpConversion>(&getContext());
    // Partial conversion: ops not matched by the patterns are left untouched.
    if (failed(applyPartialConversion(getFunction(), target, patterns))) {
      emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n");
      signalPassFailure();
    }
  }
};
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace toy {
/// Create a pass lowering the computationally intensive toy operations (see
/// the file header) to the Linalg dialect. Ownership of the returned pass
/// transfers to the caller.
std::unique_ptr<mlir::Pass> createEarlyLoweringPass() {
  return std::make_unique<EarlyLoweringPass>();
}
} // namespace toy
|
|
@ -1,470 +0,0 @@
|
|||
//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements late lowering of IR mixing Toy and Linalg to LLVM.
|
||||
// It involves intermediate steps:
|
||||
// -
|
||||
// - a mix of affine and standard dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "linalg1/Dialect.h"
|
||||
#include "linalg1/Intrinsics.h"
|
||||
#include "linalg1/ViewOp.h"
|
||||
#include "linalg3/ConvertToLLVMDialect.h"
|
||||
#include "linalg3/TensorOps.h"
|
||||
#include "linalg3/Transforms.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Utility function for type casting: this is making the type checker happy,
|
||||
/// while delaying the actual work involved to convert the type. Most of the
|
||||
/// time both side of the cast (producer and consumer) will be lowered to a
|
||||
/// dialect like LLVM and end up with the same LLVM representation, at which
|
||||
/// point this becomes a no-op and is eliminated.
|
||||
Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) {
|
||||
if (val->getType() == destTy)
|
||||
return val;
|
||||
return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
/// Create a type cast to turn a toy.array into a memref. The Toy Array will be
|
||||
/// lowered to a memref during buffer allocation, at which point the type cast
|
||||
/// becomes useless.
|
||||
Value *memRefTypeCast(PatternRewriter &builder, Value *val) {
|
||||
if (val->getType().isa<MemRefType>())
|
||||
return val;
|
||||
auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
|
||||
if (!toyArrayTy)
|
||||
return val;
|
||||
return typeCast(builder, val, toyArrayTy.toMemref());
|
||||
}
|
||||
|
||||
/// Lower a toy.add to an affine loop nest.
///
/// This class inherit from `ConversionPattern` and override `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see `LateLowering`
/// class below).
class AddOpConversion : public ConversionPattern {
public:
  explicit AddOpConversion(MLIRContext *context)
      : ConversionPattern(toy::AddOp::getOperationName(), 1, context) {}

  /// Lower the `op` by generating IR using the `rewriter` builder. The builder
  /// is setup with a new function, the `operands` array has been populated with
  /// the rewritten operands for `op` in the new function.
  /// The results created by the new IR with the builder are returned, and their
  /// number must match the number of result of `op`.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto add = cast<toy::AddOp>(op);
    auto loc = add.getLoc();
    // Create a `toy.alloc` operation to allocate the output buffer for this op,
    // then view it as a memref so the EDSC helpers below can index into it.
    Value *result = memRefTypeCast(
        rewriter, rewriter.create<toy::AllocOp>(loc, add.getResult()->getType())
                      .getResult());
    Value *lhs = memRefTypeCast(rewriter, operands[0]);
    Value *rhs = memRefTypeCast(rewriter, operands[1]);

    // The EDSC builders emit IR at the rewriter's current insertion point for
    // the duration of this scope.
    using namespace edsc;
    ScopedContext scope(rewriter, loc);
    ValueHandle zero = intrinsics::constant_index(0);
    MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
    IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
    IndexHandle i, j, M(vRes.ub(0));
    // Emit a 1-D or 2-D elementwise loop nest depending on the result rank.
    if (vRes.rank() == 1) {
      LoopNestBuilder({&i}, {zero}, {M},
                      {1})([&] { iRes(i) = iLHS(i) + iRHS(i); });
    } else {
      assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
      IndexHandle N(vRes.ub(1));
      LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
                      {1, 1})([&] { iRes(i, j) = iLHS(i, j) + iRHS(i, j); });
    }

    // Return the newly allocated buffer, with a type.cast to preserve the
    // consumers.
    rewriter.replaceOp(op, {typeCast(rewriter, result, add.getType())});
    return matchSuccess();
  }
};
|
||||
|
||||
/// Lowers `toy.print` to a loop nest calling `printf` on every individual
/// elements of the array.
class PrintOpConversion : public ConversionPattern {
public:
  explicit PrintOpConversion(MLIRContext *context)
      : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Get or create the declaration of the printf function in the module.
    LLVM::LLVMFuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());

    auto print = cast<toy::PrintOp>(op);
    auto loc = print.getLoc();
    // We will operate on a MemRef abstraction, we use a type.cast to get one
    // if our operand is still a Toy array.
    Value *operand = memRefTypeCast(rewriter, operands[0]);
    Type retTy = printfFunc.getType().getFunctionResultType();

    // Create our loop nest now
    using namespace edsc;
    using extractvalue = intrinsics::ValueBuilder<LLVM::ExtractValueOp>;
    using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
    ScopedContext scope(rewriter, loc);
    ValueHandle zero = intrinsics::constant_index(0);
    // Format string buffer for one element ("%f ") materialized as a memref.
    ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
    MemRefView vOp(operand);
    IndexedValue iOp(operand);
    IndexHandle i, j, M(vOp.ub(0));

    auto *dialect = op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
    auto i8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();

    // Separate buffer holding the end-of-row newline.
    ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
    if (vOp.rank() == 1) {
      // 1-D: one loop printing each element, then a single trailing newline.
      // clang-format off
      LoopBuilder(&i, zero, M, 1)([&]{
        llvmCall(retTy,
                 rewriter.getSymbolRefAttr(printfFunc),
                 {extractvalue(i8PtrTy, fmtCst, rewriter.getIndexArrayAttr(0)),
                  iOp(i)});
      });
      llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc),
               {extractvalue(i8PtrTy, fmtEol, rewriter.getIndexArrayAttr(0))});
      // clang-format on
    } else {
      // 2-D: nested loops; the newline is printed once per outer iteration
      // (i.e. per row).
      IndexHandle N(vOp.ub(1));
      // clang-format off
      LoopBuilder(&i, zero, M, 1)([&]{
        LoopBuilder(&j, zero, N, 1)([&]{
          llvmCall(
              retTy,
              rewriter.getSymbolRefAttr(printfFunc),
              {extractvalue(i8PtrTy, fmtCst, rewriter.getIndexArrayAttr(0)),
               iOp(i, j)});
        });
        llvmCall(
            retTy,
            rewriter.getSymbolRefAttr(printfFunc),
            {extractvalue(i8PtrTy, fmtEol, rewriter.getIndexArrayAttr(0))});
      });
      // clang-format on
    }
    // toy.print has no results; replace it with nothing.
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }

private:
  // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence
  // of stores into the buffer, and return a MemRef into the buffer.
  // The buffer is NUL-terminated so it is usable as a C format string.
  Value *getConstantCharBuffer(PatternRewriter &builder, Location loc,
                               StringRef data) const {
    auto retTy =
        builder.getMemRefType(data.size() + 1, builder.getIntegerType(8));
    Value *result = builder.create<toy::AllocOp>(loc, retTy).getResult();
    using namespace edsc;
    using intrinsics::constant_index;
    using intrinsics::constant_int;
    ScopedContext scope(builder, loc);
    MemRefView vOp(result);
    IndexedValue iOp(result);
    // Store each character, then the terminating NUL byte.
    for (uint64_t i = 0; i < data.size(); ++i) {
      iOp(constant_index(i)) = constant_int(data[i], 8);
    }
    iOp(constant_index(data.size())) = constant_int(0, 8);
    return result;
  }

  /// Return the prototype declaration for printf in the module, create it if
  /// necessary.
  LLVM::LLVMFuncOp getPrintf(ModuleOp module) const {
    auto printfFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("printf");
    if (printfFunc)
      return printfFunc;

    // Create a function declaration for printf, signature is `i32 (i8*, ...)`
    OpBuilder builder(module.getBodyRegion());
    auto *dialect =
        module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();

    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
    auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
    auto printfTy = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
                                                  /*isVarArg=*/true);
    return builder.create<LLVM::LLVMFuncOp>(builder.getUnknownLoc(), "printf",
                                            printfTy,
                                            ArrayRef<NamedAttribute>());
  }
};
|
||||
|
||||
/// Lowers constant to a sequence of store in a buffer.
class ConstantOpConversion : public ConversionPattern {
public:
  explicit ConstantOpConversion(MLIRContext *context)
      : ConversionPattern(toy::ConstantOp::getOperationName(), 1, context) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
    auto loc = cstOp.getLoc();
    auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
    auto shape = retTy.getShape();
    // Allocate the output buffer and view it as a memref for indexed stores.
    Value *result = memRefTypeCast(
        rewriter, rewriter.create<toy::AllocOp>(loc, retTy).getResult());

    auto cstValue = cstOp.getValue();
    auto f64Ty = rewriter.getF64Type();
    using namespace edsc;
    using intrinsics::constant_float;
    using intrinsics::constant_index;
    ScopedContext scope(rewriter, loc);
    MemRefView vOp(result);
    IndexedValue iOp(result);
    // Store each constant element. Only rank-1 and rank-2 shapes are handled:
    // `shape[0]` is read unconditionally — assumes shape is non-empty
    // (TODO confirm rank-0 constants cannot reach this pattern).
    for (uint64_t i = 0, ie = shape[0]; i < ie; ++i) {
      // Rank 1: a single store per element.
      if (shape.size() == 1) {
        auto value = cstValue.getValue<APFloat>(ArrayRef<uint64_t>{i});
        iOp(constant_index(i)) = constant_float(value, f64Ty);
        continue;
      }
      // Rank 2: inner loop over the second dimension.
      for (uint64_t j = 0, je = shape[1]; j < je; ++j) {
        auto value = cstValue.getValue<APFloat>(ArrayRef<uint64_t>{i, j});
        iOp(constant_index(i), constant_index(j)) =
            constant_float(value, f64Ty);
      }
    }
    rewriter.replaceOp(op, result);
    return matchSuccess();
  }
};
|
||||
|
||||
/// Lower transpose operation to an affine loop nest.
class TransposeOpConversion : public ConversionPattern {
public:
  explicit TransposeOpConversion(MLIRContext *context)
      : ConversionPattern(toy::TransposeOp::getOperationName(), 1, context) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transpose = cast<toy::TransposeOp>(op);
    auto loc = transpose.getLoc();
    // Allocate the output buffer and view it as a memref.
    Value *result = memRefTypeCast(
        rewriter,
        rewriter.create<toy::AllocOp>(loc, transpose.getResult()->getType())
            .getResult());
    Value *operand = memRefTypeCast(rewriter, operands[0]);

    using namespace edsc;
    ScopedContext scope(rewriter, loc);
    ValueHandle zero = intrinsics::constant_index(0);
    MemRefView vRes(result), vOperand(operand);
    IndexedValue iRes(result), iOperand(operand);
    // Bounds come from the *result* view; the operand is read with swapped
    // indices, which is the transpose.
    IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
    // clang-format off
    LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
      iRes(i, j) = iOperand(j, i);
    });
    // clang-format on

    // Return the buffer cast back to the type consumers expect.
    rewriter.replaceOp(op, {typeCast(rewriter, result, transpose.getType())});
    return matchSuccess();
  }
};
|
||||
|
||||
// Lower toy.return to standard return operation.
|
||||
class ReturnOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit ReturnOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(toy::ReturnOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Argument is optional, handle both cases.
|
||||
if (op->getNumOperands())
|
||||
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands[0]);
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<ReturnOp>(op);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
/// This is the main class registering our individual converter classes with
|
||||
/// the DialectConversion framework in MLIR.
|
||||
class ToyTypeConverter : public TypeConverter {
|
||||
protected:
|
||||
/// Convert a Toy type, this gets called for block and region arguments, and
|
||||
/// attributes.
|
||||
Type convertType(Type t) override {
|
||||
if (auto array = t.dyn_cast<toy::ToyArrayType>())
|
||||
return array.toMemref();
|
||||
return t;
|
||||
}
|
||||
|
||||
/// Materialize a conversion to allow for partial lowering of types.
|
||||
Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
|
||||
ArrayRef<Value *> inputs,
|
||||
Location loc) override {
|
||||
assert(inputs.size() == 1 && "expected only one input value");
|
||||
return rewriter.create<toy::TypeCastOp>(loc, inputs[0], resultType);
|
||||
}
|
||||
};
|
||||
|
||||
/// This is lowering to Linalg the parts that can be (matmul and add on arrays)
|
||||
/// and is targeting LLVM otherwise.
|
||||
struct LateLoweringPass : public ModulePass<LateLoweringPass> {
|
||||
void runOnModule() override {
|
||||
ToyTypeConverter typeConverter;
|
||||
OwningRewritePatternList toyPatterns;
|
||||
toyPatterns.insert<AddOpConversion, PrintOpConversion, ConstantOpConversion,
|
||||
TransposeOpConversion, ReturnOpConversion>(
|
||||
&getContext());
|
||||
mlir::populateFuncOpTypeConversionPattern(toyPatterns, &getContext(),
|
||||
typeConverter);
|
||||
|
||||
// Perform Toy specific lowering.
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
|
||||
LLVM::LLVMDialect, StandardOpsDialect>();
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getType());
|
||||
});
|
||||
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
|
||||
if (failed(applyPartialConversion(getModule(), target, toyPatterns,
|
||||
&typeConverter))) {
|
||||
emitError(UnknownLoc::get(getModule().getContext()),
|
||||
"error lowering Toy\n");
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
// At this point the IR is almost using only standard and affine dialects.
|
||||
// A few things remain before we emit LLVM IR. First to reuse as much of
|
||||
// MLIR as possible we will try to lower everything to the standard and/or
|
||||
// affine dialect: they already include conversion to the LLVM dialect.
|
||||
|
||||
// First patch calls type to return memref instead of ToyArray
|
||||
for (auto function : getModule().getOps<FuncOp>()) {
|
||||
function.walk([&](Operation *op) {
|
||||
auto callOp = dyn_cast<CallOp>(op);
|
||||
if (!callOp)
|
||||
return;
|
||||
if (!callOp.getNumResults())
|
||||
return;
|
||||
auto retToyTy =
|
||||
callOp.getResult(0)->getType().dyn_cast<toy::ToyArrayType>();
|
||||
if (!retToyTy)
|
||||
return;
|
||||
callOp.getResult(0)->setType(retToyTy.toMemref());
|
||||
});
|
||||
}
|
||||
|
||||
for (auto function : getModule().getOps<FuncOp>()) {
|
||||
function.walk([&](Operation *op) {
|
||||
// Turns toy.alloc into sequence of alloc/dealloc (later malloc/free).
|
||||
if (auto allocOp = dyn_cast<toy::AllocOp>(op)) {
|
||||
auto result = allocTensor(allocOp);
|
||||
allocOp.replaceAllUsesWith(result);
|
||||
allocOp.erase();
|
||||
return;
|
||||
}
|
||||
// Eliminate all type.cast before lowering to LLVM.
|
||||
if (auto typeCastOp = dyn_cast<toy::TypeCastOp>(op)) {
|
||||
typeCastOp.replaceAllUsesWith(typeCastOp.getOperand());
|
||||
typeCastOp.erase();
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Lower Linalg to affine
|
||||
for (auto function : getModule().getOps<FuncOp>())
|
||||
linalg::lowerToLoops(function);
|
||||
|
||||
getModule().dump();
|
||||
|
||||
// Finally convert to LLVM Dialect
|
||||
linalg::convertLinalg3ToLLVM(getModule());
|
||||
}
|
||||
|
||||
/// Allocate buffers (malloc/free) for Toy operations. This can't be done as
|
||||
/// part of dialect conversion framework since we need to insert `dealloc`
|
||||
/// operations just before the return, but the conversion framework is
|
||||
/// operating in a brand new function: we don't have the return to hook the
|
||||
/// dealloc operations.
|
||||
Value *allocTensor(toy::AllocOp alloc) {
|
||||
OpBuilder builder(alloc);
|
||||
auto retTy = alloc.getResult()->getType();
|
||||
|
||||
auto memRefTy = retTy.dyn_cast<MemRefType>();
|
||||
if (!memRefTy)
|
||||
memRefTy = retTy.cast<toy::ToyArrayType>().toMemref();
|
||||
if (!memRefTy) {
|
||||
alloc.emitOpError("is expected to allocate a Toy array or a MemRef");
|
||||
llvm_unreachable("fatal error");
|
||||
}
|
||||
auto loc = alloc.getLoc();
|
||||
Value *result = builder.create<AllocOp>(loc, memRefTy).getResult();
|
||||
|
||||
// Insert a `dealloc` operation right before the `return` operations, unless
|
||||
// it is returned itself in which case the caller is responsible for it.
|
||||
alloc.getParentRegion()->walk([&](Operation *op) {
|
||||
auto returnOp = dyn_cast<ReturnOp>(op);
|
||||
if (!returnOp)
|
||||
return;
|
||||
if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc)
|
||||
return;
|
||||
builder.setInsertionPoint(returnOp);
|
||||
builder.create<DeallocOp>(alloc.getLoc(), result);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace toy {
|
||||
std::unique_ptr<mlir::Pass> createLateLoweringPass() {
|
||||
return std::make_unique<LateLoweringPass>();
|
||||
}
|
||||
} // namespace toy
|
|
@ -0,0 +1,318 @@
|
|||
//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a partial lowering of Toy operations to a combination of
|
||||
// affine loops and standard operations. This lowering expects that all calls
|
||||
// have been inlined, and all shapes have been resolved.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
#include "toy/Passes.h"
|
||||
|
||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyToAffine RewritePatterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Convert the given TensorType into the corresponding MemRefType.
|
||||
static MemRefType convertTensorToMemRef(TensorType type) {
|
||||
assert(type.hasRank() && "expected only ranked shapes");
|
||||
return MemRefType::get(type.getShape(), type.getElementType());
|
||||
}
|
||||
|
||||
/// Insert an allocation and deallocation for the given MemRefType.
/// Returns the allocated value. The alloc is hoisted to the start of the
/// enclosing block and the dealloc is placed just before the block's last
/// operation.
static Value *insertAllocAndDealloc(MemRefType type, Location loc,
                                    PatternRewriter &rewriter) {
  auto alloc = rewriter.create<AllocOp>(loc, type);

  // Make sure to allocate at the beginning of the block.
  auto *parentBlock = alloc.getOperation()->getBlock();
  alloc.getOperation()->moveBefore(&parentBlock->front());

  // Make sure to deallocate this alloc at the end of the block. This is fine
  // as toy functions have no control flow.
  // (moveBefore(back()) places the dealloc just ahead of the terminator.)
  auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
  dealloc.getOperation()->moveBefore(&parentBlock->back());
  return alloc;
}
|
||||
|
||||
/// This defines the function type used to process an iteration of a lowered
/// loop. It takes as input a rewriter, an array of memRefOperands corresponding
/// to the operands of the input operation, and the set of loop induction
/// variables for the iteration. It returns a value to store at the current
/// index of the iteration.
using LoopIterationFn = function_ref<Value *(PatternRewriter &rewriter,
                                             ArrayRef<Value *> memRefOperands,
                                             ArrayRef<Value *> loopIvs)>;

/// Shared lowering skeleton: allocate the result buffer, build one affine loop
/// per result dimension, invoke `processIteration` in the innermost body to
/// compute the element, and store it into the buffer.
static void lowerOpToLoops(Operation *op, ArrayRef<Value *> operands,
                           PatternRewriter &rewriter,
                           LoopIterationFn processIteration) {
  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
  auto loc = op->getLoc();

  // Insert an allocation and deallocation for the result of this operation.
  auto memRefType = convertTensorToMemRef(tensorType);
  auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

  // Create an empty affine loop for each of the dimensions within the shape.
  SmallVector<Value *, 4> loopIvs;
  for (auto dim : tensorType.getShape()) {
    auto loop = rewriter.create<AffineForOp>(loc, /*lb=*/0, dim, /*step=*/1);
    loop.getBody()->clear();
    loopIvs.push_back(loop.getInductionVar());

    // Terminate the loop body and update the rewriter insertion point to the
    // beginning of the loop.
    // (The second setInsertionPointToStart moves the insertion point back in
    // front of the terminator just created, so the next loop nests inside.)
    rewriter.setInsertionPointToStart(loop.getBody());
    rewriter.create<AffineTerminatorOp>(loc);
    rewriter.setInsertionPointToStart(loop.getBody());
  }

  // Generate a call to the processing function with the rewriter, the memref
  // operands, and the loop induction variables. This function will return the
  // value to store at the current index.
  Value *valueToStore = processIteration(rewriter, operands, loopIvs);
  rewriter.create<AffineStoreOp>(loc, valueToStore, alloc,
                                 llvm::makeArrayRef(loopIvs));

  // Replace this operation with the generated alloc.
  rewriter.replaceOp(op, alloc);
}
|
||||
|
||||
namespace {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyToAffine RewritePatterns: Binary operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Lower an elementwise Toy binary op (`BinaryOp`) to an affine loop nest
/// whose innermost body applies the standard-dialect op `LoweredBinaryOp`.
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
  BinaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    lowerOpToLoops(
        op, operands, rewriter,
        [loc](PatternRewriter &rewriter, ArrayRef<Value *> memRefOperands,
              ArrayRef<Value *> loopIvs) {
          // Generate an adaptor for the remapped operands of the BinaryOp. This
          // allows for using the nice named accessors that are generated by the
          // ODS.
          typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands);

          // Generate loads for the element of 'lhs' and 'rhs' at the inner
          // loop.
          auto loadedLhs =
              rewriter.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
          auto loadedRhs =
              rewriter.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);

          // Create the binary operation performed on the loaded values.
          return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
        });
    return matchSuccess();
  }
};
// toy.add -> addf, toy.mul -> mulf.
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyToAffine RewritePatterns: Constant operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Lower toy.constant to an alloc'd memref filled by a sequence of affine
/// stores, one per constant element.
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
  using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(toy::ConstantOp op,
                                     PatternRewriter &rewriter) const final {
    DenseElementsAttr constantValue = op.value();
    Location loc = op.getLoc();

    // When lowering the constant operation, we allocate and assign the constant
    // values to a corresponding memref allocation.
    auto tensorType = op.getType().cast<TensorType>();
    auto memRefType = convertTensorToMemRef(tensorType);
    auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

    // We will be generating constant indices up-to the largest dimension.
    // Create these constants up-front to avoid large amounts of redundant
    // operations.
    // Guard against an empty shape (rank-0 constant): std::max_element on an
    // empty range returns end(), and dereferencing it is undefined behavior.
    // A rank-0 constant needs no index constants at all — the store below is
    // emitted with an empty index list.
    auto valueShape = memRefType.getShape();
    SmallVector<Value *, 8> constantIndices;
    if (!valueShape.empty()) {
      for (auto i : llvm::seq<int64_t>(
               0, *std::max_element(valueShape.begin(), valueShape.end())))
        constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
    }

    // The constant operation represents a multi-dimensional constant, so we
    // will need to generate a store for each of the elements. The following
    // functor recursively walks the dimensions of the constant shape,
    // generating a store when the recursion hits the base case.
    SmallVector<Value *, 2> indices;
    auto valueIt = constantValue.getValues<FloatAttr>().begin();
    std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
      // The last dimension is the base case of the recursion, at this point
      // we store the element at the given index.
      if (dimension == valueShape.size()) {
        rewriter.create<AffineStoreOp>(
            loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
            llvm::makeArrayRef(indices));
        return;
      }

      // Otherwise, iterate over the current dimension and add the indices to
      // the list.
      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
        indices.push_back(constantIndices[i]);
        storeElements(dimension + 1);
        indices.pop_back();
      }
    };

    // Start the element storing recursion from the first dimension.
    storeElements(/*dimension=*/0);

    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
    return matchSuccess();
  }
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyToAffine RewritePatterns: Return operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Lower a value-less toy.return directly to std.return. A toy.return that
/// still carries an operand is rejected: all calls are expected to have been
/// inlined before this lowering runs.
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
  using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(toy::ReturnOp op,
                                     PatternRewriter &rewriter) const final {
    // Only the operand-less form is legal here; anything else means inlining
    // did not happen as expected.
    if (!op.hasOperand()) {
      rewriter.replaceOpWithNewOp<ReturnOp>(op);
      return matchSuccess();
    }
    return matchFailure();
  }
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyToAffine RewritePatterns: Transpose operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct TransposeOpLowering : public ConversionPattern {
|
||||
TransposeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](PatternRewriter &rewriter, ArrayRef<Value *> memRefOperands,
|
||||
ArrayRef<Value *> loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the TransposeOp.
|
||||
// This allows for using the nice named accessors that are generated
|
||||
// by the ODS.
|
||||
toy::TransposeOpOperandAdaptor tranposeAdaptor(memRefOperands);
|
||||
Value *input = tranposeAdaptor.input();
|
||||
|
||||
// Transpose the elements by generating a load from the reverse
|
||||
// indices.
|
||||
SmallVector<Value *, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
|
||||
});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyToAffineLoweringPass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This is a partial lowering to affine loops of the toy operations that are
/// computationally intensive (like matmul for example...) while keeping the
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> {
  // Defined out-of-line below.
  void runOnFunction() final;
};
} // end anonymous namespace.
|
||||
|
||||
/// Drive the partial Toy-to-Affine+Std conversion on the `main` function.
void ToyToAffineLoweringPass::runOnFunction() {
  auto function = getFunction();

  // We only lower the main function as we expect that all other functions have
  // been inlined.
  if (function.getName() != "main")
    return;

  // Verify that the given main has no inputs and results.
  if (function.getNumArguments() || function.getType().getNumResults()) {
    function.emitError("expected 'main' to have 0 inputs and 0 results");
    return signalPassFailure();
  }

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering. In our case, we are lowering to a combination of the
  // `Affine` and `Standard` dialects.
  target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();

  // We also define the Toy dialect as Illegal so that the conversion will fail
  // if any of these operations are *not* converted. Given that we actually want
  // a partial lowering, we explicitly mark the Toy operations that we don't
  // want to lower, `toy.print`, as `legal`.
  target.addIllegalDialect<toy::ToyDialect>();
  target.addLegalOp<toy::PrintOp>();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the Toy operations.
  OwningRewritePatternList patterns;
  patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
                  ReturnOpLowering, TransposeOpLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(applyPartialConversion(getFunction(), target, patterns)))
    signalPassFailure();
}
|
||||
|
||||
/// Create a pass for lowering operations in the `Affine` and `Std` dialects,
|
||||
/// for a subset of the Toy IR (e.g. matmul).
|
||||
std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
|
||||
return std::make_unique<ToyToAffineLoweringPass>();
|
||||
}
|
|
@ -25,30 +25,30 @@
|
|||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ScopedHashTable.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir::toy;
|
||||
using namespace toy;
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::cast;
|
||||
using llvm::dyn_cast;
|
||||
using llvm::isa;
|
||||
using llvm::makeArrayRef;
|
||||
using llvm::ScopedHashTableScope;
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
using std::make_unique;
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -57,56 +57,43 @@ namespace {
|
|||
/// This will emit operations that are specific to the Toy language, preserving
|
||||
/// the semantics of the language and (hopefully) allow to perform accurate
|
||||
/// analysis and transformation based on these high level semantics.
|
||||
///
|
||||
/// At this point we take advantage of the "raw" MLIR APIs to create operations
|
||||
/// that haven't been registered in any way with MLIR. These operations are
|
||||
/// unknown to MLIR, custom passes could operate by string-matching the name of
|
||||
/// these operations, but no other type checking or semantic is associated with
|
||||
/// them natively by MLIR.
|
||||
class MLIRGenImpl {
|
||||
public:
|
||||
MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
|
||||
MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
|
||||
|
||||
/// Public API: convert the AST for a Toy module (source file) to an MLIR
|
||||
/// Module.
|
||||
mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
|
||||
/// Module operation.
|
||||
mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
|
||||
// We create an empty MLIR module and codegen functions one at a time and
|
||||
// add them to the module.
|
||||
theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
||||
|
||||
for (FunctionAST &F : moduleAST) {
|
||||
auto func = mlirGen(F);
|
||||
if (!func)
|
||||
return nullptr;
|
||||
theModule->push_back(func);
|
||||
theModule.push_back(func);
|
||||
}
|
||||
|
||||
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
|
||||
// this won't do much, but it should at least check some structural
|
||||
// properties.
|
||||
if (failed(mlir::verify(*theModule))) {
|
||||
emitError(mlir::UnknownLoc::get(&context), "module verification error");
|
||||
// Verify the module after we have finished constructing it, this will check
|
||||
// the structural properties of the IR and invoke any specific verifiers we
|
||||
// have on the Toy operations.
|
||||
if (failed(mlir::verify(theModule))) {
|
||||
theModule.emitError("module verification error");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return std::move(theModule);
|
||||
return theModule;
|
||||
}
|
||||
|
||||
private:
|
||||
/// In MLIR (like in LLVM) a "context" object holds the memory allocation and
|
||||
/// the ownership of many internal structure of the IR and provide a level
|
||||
/// of "uniquing" across multiple modules (types for instance).
|
||||
mlir::MLIRContext &context;
|
||||
/// A "module" matches a Toy source file: containing a list of functions.
|
||||
mlir::ModuleOp theModule;
|
||||
|
||||
/// A "module" matches a source file: it contains a list of functions.
|
||||
mlir::OwningModuleRef theModule;
|
||||
|
||||
/// The builder is a helper class to create IR inside a function. It is
|
||||
/// re-initialized every time we enter a function and kept around as a
|
||||
/// convenience for emitting individual operations.
|
||||
/// The builder is stateful, in particular it keeeps an "insertion point":
|
||||
/// this is where the next operations will be introduced.
|
||||
std::unique_ptr<mlir::OpBuilder> builder;
|
||||
/// The builder is a helper class to create IR inside a function. The builder
|
||||
/// is stateful, in particular it keeeps an "insertion point": this is where
|
||||
/// the next operations will be introduced.
|
||||
mlir::OpBuilder builder;
|
||||
|
||||
/// The symbol table maps a variable name to a value in the current scope.
|
||||
/// Entering a function creates a new scope, and the function arguments are
|
||||
|
@ -116,37 +103,35 @@ private:
|
|||
|
||||
/// Helper conversion for a Toy AST location to an MLIR location.
|
||||
mlir::Location loc(Location loc) {
|
||||
return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context),
|
||||
loc.line, loc.col, &context);
|
||||
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
|
||||
loc.col);
|
||||
}
|
||||
|
||||
/// Declare a variable in the current scope, return true if the variable
|
||||
/// Declare a variable in the current scope, return success if the variable
|
||||
/// wasn't declared yet.
|
||||
bool declare(llvm::StringRef var, mlir::Value *value) {
|
||||
if (symbolTable.count(var)) {
|
||||
return false;
|
||||
}
|
||||
mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
|
||||
if (symbolTable.count(var))
|
||||
return mlir::failure();
|
||||
symbolTable.insert(var, value);
|
||||
return true;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Create the prototype for an MLIR function with as many arguments as the
|
||||
/// provided Toy AST prototype.
|
||||
mlir::FuncOp mlirGen(PrototypeAST &proto) {
|
||||
auto location = loc(proto.loc());
|
||||
|
||||
// This is a generic function, the return type will be inferred later.
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
// Arguments type is uniformly a generic array.
|
||||
// Arguments type are uniformly unranked tensors.
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
||||
getType(VarType{}));
|
||||
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
|
||||
auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
auto func_type = builder.getFunctionType(arg_types, llvm::None);
|
||||
auto function = mlir::FuncOp::create(location, proto.getName(), func_type);
|
||||
|
||||
// Mark the function as generic: it'll require type specialization for every
|
||||
// call site.
|
||||
if (function.getNumArguments())
|
||||
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
|
||||
function.setAttr("toy.generic", builder.getUnitAttr());
|
||||
return function;
|
||||
}
|
||||
|
||||
|
@ -165,29 +150,39 @@ private:
|
|||
// argument list as the function itself.
|
||||
auto &entryBlock = *function.addEntryBlock();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
llvm::zip(protoArgs, entryBlock.getArguments())) {
|
||||
declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
|
||||
if (failed(declare(std::get<0>(name_value)->getName(),
|
||||
std::get<1>(name_value))))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Create a builder for the function, it will be used throughout the codegen
|
||||
// to create operations in this function.
|
||||
builder = std::make_unique<mlir::OpBuilder>(function.getBody());
|
||||
// Set the insertion point in the builder to the beginning of the function
|
||||
// body, it will be used throughout the codegen to create operations in this
|
||||
// function.
|
||||
builder.setInsertionPointToStart(&entryBlock);
|
||||
|
||||
// Emit the body of the function.
|
||||
if (!mlirGen(*funcAST.getBody())) {
|
||||
if (mlir::failed(mlirGen(*funcAST.getBody()))) {
|
||||
function.erase();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Implicitly return void if no return statement was emited.
|
||||
// Implicitly return void if no return statement was emitted.
|
||||
// FIXME: we may fix the parser instead to always return the last expression
|
||||
// (this would possibly help the REPL case later)
|
||||
if (function.getBlocks().back().back().getName().getStringRef() !=
|
||||
"toy.return") {
|
||||
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
|
||||
mlirGen(fakeRet);
|
||||
ReturnOp returnOp;
|
||||
if (!entryBlock.empty())
|
||||
returnOp = dyn_cast<ReturnOp>(entryBlock.back());
|
||||
if (!returnOp) {
|
||||
builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
|
||||
} else if (returnOp.hasOperand()) {
|
||||
// Otherwise, if this return operation has an operand then add a result to
|
||||
// the function.
|
||||
function.setType(builder.getFunctionType(function.getType().getInputs(),
|
||||
getType(VarType{})));
|
||||
}
|
||||
|
||||
return function;
|
||||
|
@ -206,11 +201,11 @@ private:
|
|||
// and the result value is returned. If an error occurs we get a nullptr
|
||||
// and propagate.
|
||||
//
|
||||
mlir::Value *L = mlirGen(*binop.getLHS());
|
||||
if (!L)
|
||||
mlir::Value *lhs = mlirGen(*binop.getLHS());
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
mlir::Value *R = mlirGen(*binop.getRHS());
|
||||
if (!R)
|
||||
mlir::Value *rhs = mlirGen(*binop.getRHS());
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
auto location = loc(binop.loc());
|
||||
|
||||
|
@ -218,123 +213,113 @@ private:
|
|||
// support '+' and '*'.
|
||||
switch (binop.getOp()) {
|
||||
case '+':
|
||||
return builder->create<AddOp>(location, L, R).getResult();
|
||||
break;
|
||||
return builder.create<AddOp>(location, lhs, rhs);
|
||||
case '*':
|
||||
return builder->create<MulOp>(location, L, R).getResult();
|
||||
default:
|
||||
emitError(loc(binop.loc()), "error: invalid binary operator '")
|
||||
<< binop.getOp() << "'";
|
||||
return nullptr;
|
||||
return builder.create<MulOp>(location, lhs, rhs);
|
||||
}
|
||||
|
||||
emitError(location, "invalid binary operator '") << binop.getOp() << "'";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// This is a reference to a variable in an expression. The variable is
|
||||
// expected to have been declared and so should have a value in the symbol
|
||||
// table, otherwise emit an error and return nullptr.
|
||||
/// This is a reference to a variable in an expression. The variable is
|
||||
/// expected to have been declared and so should have a value in the symbol
|
||||
/// table, otherwise emit an error and return nullptr.
|
||||
mlir::Value *mlirGen(VariableExprAST &expr) {
|
||||
if (symbolTable.count(expr.getName()))
|
||||
return symbolTable.lookup(expr.getName());
|
||||
if (auto *variable = symbolTable.lookup(expr.getName()))
|
||||
return variable;
|
||||
|
||||
emitError(loc(expr.loc()), "error: unknown variable '")
|
||||
<< expr.getName() << "'";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Emit a return operation, return true on success.
|
||||
bool mlirGen(ReturnExprAST &ret) {
|
||||
/// Emit a return operation. This will return failure if any generation fails.
|
||||
mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
|
||||
auto location = loc(ret.loc());
|
||||
// `return` takes an optional expression, we need to account for it here.
|
||||
if (!ret.getExpr().hasValue()) {
|
||||
builder->create<ReturnOp>(location);
|
||||
return true;
|
||||
|
||||
// 'return' takes an optional expression, handle that case here.
|
||||
mlir::Value *expr = nullptr;
|
||||
if (ret.getExpr().hasValue()) {
|
||||
if (!(expr = mlirGen(*ret.getExpr().getValue())))
|
||||
return mlir::failure();
|
||||
}
|
||||
auto *expr = mlirGen(*ret.getExpr().getValue());
|
||||
if (!expr)
|
||||
return false;
|
||||
builder->create<ReturnOp>(location, expr);
|
||||
return true;
|
||||
|
||||
// Otherwise, this return operation has zero operands.
|
||||
builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
|
||||
: ArrayRef<mlir::Value *>());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Emit a literal/constant array. It will be emitted as a flattened array of
|
||||
// data in an Attribute attached to a `toy.constant` operation.
|
||||
// See documentation on [Attributes](LangRef.md#attributes) for more details.
|
||||
// Here is an excerpt:
|
||||
//
|
||||
// Attributes are the mechanism for specifying constant data in MLIR in
|
||||
// places where a variable is never allowed [...]. They consist of a name
|
||||
// and a [concrete attribute value](#attribute-values). It is possible to
|
||||
// attach attributes to operations, functions, and function arguments. The
|
||||
// set of expected attributes, their structure, and their interpretation
|
||||
// are all contextually dependent on what they are attached to.
|
||||
//
|
||||
// Example, the source level statement:
|
||||
// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
|
||||
// will be converted to:
|
||||
// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
|
||||
// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
|
||||
// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
|
||||
//
|
||||
/// Emit a literal/constant array. It will be emitted as a flattened array of
|
||||
/// data in an Attribute attached to a `toy.constant` operation.
|
||||
/// See documentation on [Attributes](LangRef.md#attributes) for more details.
|
||||
/// Here is an excerpt:
|
||||
///
|
||||
/// Attributes are the mechanism for specifying constant data in MLIR in
|
||||
/// places where a variable is never allowed [...]. They consist of a name
|
||||
/// and a concrete attribute value. The set of expected attributes, their
|
||||
/// structure, and their interpretation are all contextually dependent on
|
||||
/// what they are attached to.
|
||||
///
|
||||
/// Example, the source level statement:
|
||||
/// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
|
||||
/// will be converted to:
|
||||
/// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
|
||||
/// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
|
||||
/// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
|
||||
///
|
||||
mlir::Value *mlirGen(LiteralExprAST &lit) {
|
||||
auto location = loc(lit.loc());
|
||||
// The attribute is a vector with an attribute per element (number) in the
|
||||
// array, see `collectData()` below for more details.
|
||||
std::vector<mlir::Attribute> data;
|
||||
auto type = getType(lit.getDims());
|
||||
|
||||
// The attribute is a vector with a floating point value per element
|
||||
// (number) in the array, see `collectData()` below for more details.
|
||||
std::vector<double> data;
|
||||
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
|
||||
std::multiplies<int>()));
|
||||
collectData(lit, data);
|
||||
|
||||
// FIXME: using a tensor type is a HACK here.
|
||||
// Can we do differently without registering a dialect? Using a string blob?
|
||||
mlir::Type elementType = mlir::FloatType::getF64(&context);
|
||||
auto dataType = builder->getTensorType(lit.getDims(), elementType);
|
||||
// The type of this attribute is tensor of 64-bit floating-point with the
|
||||
// shape of the literal.
|
||||
mlir::Type elementType = builder.getF64Type();
|
||||
auto dataType = builder.getTensorType(lit.getDims(), elementType);
|
||||
|
||||
// This is the actual attribute that actually hold the list of values for
|
||||
// this array literal.
|
||||
auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
|
||||
.cast<mlir::DenseElementsAttr>();
|
||||
// This is the actual attribute that holds the list of values for this
|
||||
// tensor literal.
|
||||
auto dataAttribute =
|
||||
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
|
||||
|
||||
// Build the MLIR op `toy.constant`, only boilerplate below.
|
||||
return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
|
||||
.getResult();
|
||||
// Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
|
||||
// method.
|
||||
return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
|
||||
}
|
||||
|
||||
// Recursive helper function to accumulate the data that compose an array
|
||||
// literal. It flattens the nested structure in the supplied vector. For
|
||||
// example with this array:
|
||||
// [[1, 2], [3, 4]]
|
||||
// we will generate:
|
||||
// [ 1, 2, 3, 4 ]
|
||||
// Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
|
||||
// Attributes are the way MLIR attaches constant to operations and functions.
|
||||
void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
|
||||
/// Recursive helper function to accumulate the data that compose an array
|
||||
/// literal. It flattens the nested structure in the supplied vector. For
|
||||
/// example with this array:
|
||||
/// [[1, 2], [3, 4]]
|
||||
/// we will generate:
|
||||
/// [ 1, 2, 3, 4 ]
|
||||
/// Individual numbers are represented as doubles.
|
||||
/// Attributes are the way MLIR attaches constant to operations.
|
||||
void collectData(ExprAST &expr, std::vector<double> &data) {
|
||||
if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
|
||||
for (auto &value : lit->getValues())
|
||||
collectData(*value, data);
|
||||
return;
|
||||
}
|
||||
|
||||
assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
|
||||
mlir::Type elementType = mlir::FloatType::getF64(&context);
|
||||
auto attr = mlir::FloatAttr::getChecked(
|
||||
elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
|
||||
data.push_back(attr);
|
||||
data.push_back(cast<NumberExprAST>(expr).getValue());
|
||||
}
|
||||
|
||||
// Emit a call expression. It emits specific operations for the `transpose`
|
||||
// builtin. Other identifiers are assumed to be user-defined functions.
|
||||
/// Emit a call expression. It emits specific operations for the `transpose`
|
||||
/// builtin. Other identifiers are assumed to be user-defined functions.
|
||||
mlir::Value *mlirGen(CallExprAST &call) {
|
||||
llvm::StringRef callee = call.getCallee();
|
||||
auto location = loc(call.loc());
|
||||
std::string callee = call.getCallee();
|
||||
if (callee == "transpose") {
|
||||
if (call.getArgs().size() != 1) {
|
||||
emitError(location, "MLIR codegen encountered an error: toy.transpose "
|
||||
"does not accept multiple arguments");
|
||||
return nullptr;
|
||||
}
|
||||
mlir::Value *arg = mlirGen(*call.getArgs()[0]);
|
||||
return builder->create<TransposeOp>(location, arg).getResult();
|
||||
}
|
||||
|
||||
// Codegen the operands first
|
||||
// Codegen the operands first.
|
||||
SmallVector<mlir::Value *, 4> operands;
|
||||
for (auto &expr : call.getArgs()) {
|
||||
auto *arg = mlirGen(*expr);
|
||||
|
@ -342,34 +327,41 @@ private:
|
|||
return nullptr;
|
||||
operands.push_back(arg);
|
||||
}
|
||||
// Calls to user-defined function are mapped to a custom call that takes
|
||||
// the callee name as an attribute.
|
||||
return builder->create<GenericCallOp>(location, call.getCallee(), operands)
|
||||
.getResult();
|
||||
|
||||
// Builting calls have their custom operation, meaning this is a
|
||||
// straightforward emission.
|
||||
if (callee == "transpose") {
|
||||
if (call.getArgs().size() != 1) {
|
||||
emitError(location, "MLIR codegen encountered an error: toy.transpose "
|
||||
"does not accept multiple arguments");
|
||||
return nullptr;
|
||||
}
|
||||
return builder.create<TransposeOp>(location, operands[0]);
|
||||
}
|
||||
|
||||
// Otherwise this is a call to a user-defined function. Calls to ser-defined
|
||||
// functions are mapped to a custom call that takes the callee name as an
|
||||
// attribute.
|
||||
return builder.create<GenericCallOp>(location, callee, operands);
|
||||
}
|
||||
|
||||
// Emit a call expression. It emits specific operations for two builtins:
|
||||
// transpose(x) and print(x). Other identifiers are assumed to be user-defined
|
||||
// functions. Return false on failure.
|
||||
bool mlirGen(PrintExprAST &call) {
|
||||
/// Emit a print expression. It emits specific operations for two builtins:
|
||||
/// transpose(x) and print(x).
|
||||
mlir::LogicalResult mlirGen(PrintExprAST &call) {
|
||||
auto *arg = mlirGen(*call.getArg());
|
||||
if (!arg)
|
||||
return false;
|
||||
auto location = loc(call.loc());
|
||||
builder->create<PrintOp>(location, arg);
|
||||
return true;
|
||||
return mlir::failure();
|
||||
|
||||
builder.create<PrintOp>(loc(call.loc()), arg);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Emit a constant for a single number (FIXME: semantic? broadcast?)
|
||||
/// Emit a constant for a single number (FIXME: semantic? broadcast?)
|
||||
mlir::Value *mlirGen(NumberExprAST &num) {
|
||||
auto location = loc(num.loc());
|
||||
mlir::Type elementType = mlir::FloatType::getF64(&context);
|
||||
auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
|
||||
loc(num.loc()));
|
||||
return builder->create<ConstantOp>(location, attr).getResult();
|
||||
return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
|
||||
}
|
||||
|
||||
// Dispatch codegen for the right expression subclass using RTTI.
|
||||
/// Dispatch codegen for the right expression subclass using RTTI.
|
||||
mlir::Value *mlirGen(ExprAST &expr) {
|
||||
switch (expr.getKind()) {
|
||||
case toy::ExprAST::Expr_BinOp:
|
||||
|
@ -383,84 +375,82 @@ private:
|
|||
case toy::ExprAST::Expr_Num:
|
||||
return mlirGen(cast<NumberExprAST>(expr));
|
||||
default:
|
||||
emitError(loc(expr.loc()),
|
||||
"MLIR codegen encountered an unhandled expr kind '")
|
||||
emitError(loc(expr.loc()))
|
||||
<< "MLIR codegen encountered an unhandled expr kind '"
|
||||
<< Twine(expr.getKind()) << "'";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle a variable declaration, we'll codegen the expression that forms the
|
||||
// initializer and record the value in the symbol table before returning it.
|
||||
// Future expressions will be able to reference this variable through symbol
|
||||
// table lookup.
|
||||
/// Handle a variable declaration, we'll codegen the expression that forms the
|
||||
/// initializer and record the value in the symbol table before returning it.
|
||||
/// Future expressions will be able to reference this variable through symbol
|
||||
/// table lookup.
|
||||
mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
|
||||
mlir::Value *value = nullptr;
|
||||
auto location = loc(vardecl.loc());
|
||||
if (auto init = vardecl.getInitVal()) {
|
||||
value = mlirGen(*init);
|
||||
if (!value)
|
||||
return nullptr;
|
||||
// We have the initializer value, but in case the variable was declared
|
||||
// with specific shape, we emit a "reshape" operation. It will get
|
||||
// optimized out later as needed.
|
||||
if (!vardecl.getType().shape.empty()) {
|
||||
value = builder
|
||||
->create<ReshapeOp>(
|
||||
location, value,
|
||||
getType(vardecl.getType()).cast<ToyArrayType>())
|
||||
.getResult();
|
||||
}
|
||||
} else {
|
||||
auto init = vardecl.getInitVal();
|
||||
if (!init) {
|
||||
emitError(loc(vardecl.loc()),
|
||||
"missing initializer in variable declaration");
|
||||
return nullptr;
|
||||
}
|
||||
// Register the value in the symbol table
|
||||
declare(vardecl.getName(), value);
|
||||
|
||||
mlir::Value *value = mlirGen(*init);
|
||||
if (!value)
|
||||
return nullptr;
|
||||
|
||||
// We have the initializer value, but in case the variable was declared
|
||||
// with specific shape, we emit a "reshape" operation. It will get
|
||||
// optimized out later as needed.
|
||||
if (!vardecl.getType().shape.empty()) {
|
||||
value = builder.create<ReshapeOp>(loc(vardecl.loc()),
|
||||
getType(vardecl.getType()), value);
|
||||
}
|
||||
|
||||
// Register the value in the symbol table.
|
||||
if (failed(declare(vardecl.getName(), value)))
|
||||
return nullptr;
|
||||
return value;
|
||||
}
|
||||
|
||||
/// Codegen a list of expression, return false if one of them hit an error.
|
||||
bool mlirGen(ExprASTList &blockAST) {
|
||||
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
|
||||
/// Codegen a list of expression, return failure if one of them hit an error.
|
||||
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
||||
ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
|
||||
for (auto &expr : blockAST) {
|
||||
// Specific handling for variable declarations, return statement, and
|
||||
// print. These can only appear in block list and not in nested
|
||||
// expressions.
|
||||
if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
|
||||
if (!mlirGen(*vardecl))
|
||||
return false;
|
||||
return mlir::failure();
|
||||
continue;
|
||||
}
|
||||
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
|
||||
if (!mlirGen(*ret))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
|
||||
return mlirGen(*ret);
|
||||
if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
|
||||
if (!mlirGen(*print))
|
||||
return false;
|
||||
if (mlir::failed(mlirGen(*print)))
|
||||
return mlir::success();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Generic expression dispatch codegen.
|
||||
if (!mlirGen(*expr))
|
||||
return false;
|
||||
return mlir::failure();
|
||||
}
|
||||
return true;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Build a type from a list of shape dimensions. Types are `array` followed
|
||||
/// by an optional dimension list, example: array<2, 2>
|
||||
/// They are wrapped in a `toy` dialect (see next chapter) and get printed:
|
||||
/// !toy.array<2, 2>
|
||||
template <typename T> mlir::Type getType(T shape) {
|
||||
SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
|
||||
return ToyArrayType::get(&context, shape64);
|
||||
/// Build a tensor type from a list of shape dimensions.
|
||||
mlir::Type getType(ArrayRef<int64_t> shape) {
|
||||
// If the shape is empty, then this type is unranked.
|
||||
if (shape.empty())
|
||||
return builder.getTensorType(builder.getF64Type());
|
||||
|
||||
// Otherwise, we use the given shape.
|
||||
return builder.getTensorType(shape, builder.getF64Type());
|
||||
}
|
||||
|
||||
/// Build an MLIR type from a Toy AST variable type
|
||||
/// (forward to the generic getType(T) above).
|
||||
/// Build an MLIR type from a Toy AST variable type (forward to the generic
|
||||
/// getType above).
|
||||
mlir::Type getType(const VarType &type) { return getType(type.shape); }
|
||||
};
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
|
||||
//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -15,213 +15,55 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a Module level pass performing interprocedural
|
||||
// This file implements a Function level pass performing interprocedural
|
||||
// propagation of array shapes through function specialization.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "toy/Dialect.h"
|
||||
#include "toy/Passes.h"
|
||||
#include "toy/ShapeInferenceInterface.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <algorithm>
|
||||
|
||||
#define DEBUG_TYPE "toy-shape-inference"
|
||||
#define DEBUG_TYPE "shape-inference"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace toy;
|
||||
using llvm::MutableArrayRef;
|
||||
using llvm::SmallVector;
|
||||
using llvm::SmallVectorImpl;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
|
||||
/// Create mangled name for function specialization. We will simply append the
|
||||
/// shape of the arguments to the function name. For example calling
|
||||
///
|
||||
/// "toy.generic_call"(%1, %3) {callee: "foo"}
|
||||
/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
|
||||
///
|
||||
/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could
|
||||
/// have provide a function with a similar name. But we will claim this as a
|
||||
/// feature: this allow the user to provide custom specialization!
|
||||
static std::string mangle(StringRef funcName,
|
||||
MutableArrayRef<mlir::OpOperand> operands) {
|
||||
std::string mangledName;
|
||||
mangledName.reserve(funcName.size() + operands.size() * 6);
|
||||
mangledName = funcName;
|
||||
for (auto &operand : operands) {
|
||||
auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
|
||||
mangledName += "_";
|
||||
const char *sep = "";
|
||||
for (auto dim : arrayTy.getShape()) {
|
||||
mangledName += (sep + Twine(dim)).str();
|
||||
sep = "x";
|
||||
}
|
||||
}
|
||||
return mangledName;
|
||||
}
|
||||
/// Include the auto-generated definitions for the shape inference interfaces.
|
||||
#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
|
||||
/// whole. MLIR also supports FunctionPass which are restricted to modify a
|
||||
/// single function at a time. This pass couldn't be a function pass due the
|
||||
/// nature of its interprocedural transformations.
|
||||
/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
|
||||
/// shape inference.
|
||||
///
|
||||
/// The algorithm has two levels, first intra-procedurally:
|
||||
/// Algorithm:
|
||||
///
|
||||
/// 1) Build a worklist containing all the operations that are returning
|
||||
/// a generic Toy array: these are the operations that need shape
|
||||
/// 1) Build a worklist containing all the operations that return a
|
||||
/// dynamically shaped tensor: these are the operations that need shape
|
||||
/// inference.
|
||||
/// 2) Iterate on the worklist:
|
||||
/// a) find an operation to process: the next ready operation in the
|
||||
/// worklist has all of its arguments non-generic,
|
||||
/// b) if no operation is found, break out of the loop,
|
||||
/// c) remove the operation from the worklist,
|
||||
/// d) infer the shape of its output from the arguments type.
|
||||
/// 3) If the worklist is empty, the algorithm succeeded and we infer the
|
||||
/// return type for the function from the return operation.
|
||||
/// d) infer the shape of its output from the argument types.
|
||||
/// 3) If the worklist is empty, the algorithm succeeded.
|
||||
///
|
||||
/// There is a twist though: when a call to a generic function is encountered,
|
||||
/// shape inference requires the return type of the callee to be inferred first.
|
||||
/// At this point we need to run specialize the callee by cloning it. Here is
|
||||
/// the inter-procedural flow:
|
||||
///
|
||||
/// 1) Keep a worklist of function to process. Start with function "main".
|
||||
/// 2) While the worklist isn't empty:
|
||||
/// a) Take the last inserted function in the worklist.
|
||||
/// b) Run the intra-procedural shape inference on this function.
|
||||
/// c) If the intra-procedural shape inference can't complete, it returns
|
||||
/// a FuncOp that needs to be inferred first. In this case, queue this
|
||||
/// new function and continue. Otherwise the inference succeeded and we
|
||||
/// can pop from the queue.
|
||||
///
|
||||
class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
|
||||
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||
public:
|
||||
// One entry in the inter-procedural worklist. It keeps track of the
|
||||
// function to process, the mangled name for this specialization, and the
|
||||
// types of the arguments on which to specialize.
|
||||
struct FunctionToSpecialize {
|
||||
mlir::FuncOp function;
|
||||
std::string mangledName;
|
||||
SmallVector<mlir::Type, 4> argumentsType;
|
||||
};
|
||||
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
mlir::ModuleManager moduleManager(module);
|
||||
auto main = moduleManager.lookupSymbol<mlir::FuncOp>("main");
|
||||
if (!main) {
|
||||
emitError(mlir::UnknownLoc::get(module.getContext()),
|
||||
"shape inference failed: can't find a main function\n");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
/// Inter-procedural loop, initialize with `main` and iterate till
|
||||
/// successfully infer the full reachable call-graph from main.
|
||||
SmallVector<FunctionToSpecialize, 8> worklist;
|
||||
worklist.push_back({main, "", {}});
|
||||
while (!worklist.empty()) {
|
||||
if (failed(specialize(worklist, moduleManager)))
|
||||
return;
|
||||
}
|
||||
|
||||
// Delete any generic function left
|
||||
// FIXME: we may want this as a separate pass.
|
||||
for (mlir::FuncOp function :
|
||||
llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
|
||||
if (auto genericAttr =
|
||||
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
|
||||
if (genericAttr.getValue())
|
||||
function.erase();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run inference on a function. If a mangledName is provided, we need to
|
||||
/// specialize the function: to this end clone it first.
|
||||
mlir::LogicalResult
|
||||
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist,
|
||||
mlir::ModuleManager &moduleManager) {
|
||||
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
|
||||
mlir::FuncOp f = functionToSpecialize.function;
|
||||
|
||||
// Check if cloning for specialization is needed (usually anything but main)
|
||||
// We will create a new function with the concrete types for the parameters
|
||||
// and clone the body into it.
|
||||
if (!functionToSpecialize.mangledName.empty()) {
|
||||
if (moduleManager.lookupSymbol<mlir::FuncOp>(
|
||||
functionToSpecialize.mangledName)) {
|
||||
funcWorklist.pop_back();
|
||||
// FuncOp already specialized, move on.
|
||||
return mlir::success();
|
||||
}
|
||||
// Create a new function with a generic array return type, it will be
|
||||
// updated when the inference for the function body completes.
|
||||
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
|
||||
{ToyArrayType::get(&getContext())},
|
||||
&getContext());
|
||||
auto newFunction =
|
||||
mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName,
|
||||
type, f.getDialectAttrs());
|
||||
moduleManager.insert(newFunction);
|
||||
|
||||
// Clone the function body
|
||||
mlir::BlockAndValueMapping mapper;
|
||||
f.cloneInto(newFunction, mapper);
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "====== Cloned : \n";
|
||||
f.dump();
|
||||
llvm::dbgs() << "====== Into : \n";
|
||||
newFunction.dump();
|
||||
});
|
||||
f = newFunction;
|
||||
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
// Remap the entry-block arguments
|
||||
// FIXME: this seems like a bug in `cloneInto()` above?
|
||||
auto &entryBlock = f.getBlocks().front();
|
||||
int blockArgSize = entryBlock.getArguments().size();
|
||||
assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
|
||||
entryBlock.addArguments(f.getType().getInputs());
|
||||
auto argList = entryBlock.getArguments();
|
||||
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
|
||||
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
|
||||
entryBlock.eraseArgument(0);
|
||||
}
|
||||
assert(succeeded(verify(f)));
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Run shape inference on : '" << f.getName() << "'\n");
|
||||
|
||||
auto *toyDialect = getContext().getRegisteredDialect("toy");
|
||||
if (!toyDialect) {
|
||||
signalPassFailure();
|
||||
return emitError(mlir::UnknownLoc::get(&getContext()),
|
||||
"Toy dialect is not registered");
|
||||
}
|
||||
void runOnFunction() override {
|
||||
auto f = getFunction();
|
||||
|
||||
// Populate the worklist with the operations that need shape inference:
|
||||
// these are the Toy operations that return a generic array.
|
||||
// these are operations that return a dynamic shape.
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (op->getDialect() == toyDialect) {
|
||||
if (op->getNumResults() == 1 &&
|
||||
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
|
||||
opWorklist.insert(op);
|
||||
}
|
||||
if (returnsDynamicShape(op))
|
||||
opWorklist.insert(op);
|
||||
});
|
||||
|
||||
// Iterate on the operations in the worklist until all operations have been
|
||||
|
@ -229,152 +71,43 @@ public:
|
|||
while (!opWorklist.empty()) {
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
|
||||
return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
|
||||
return !ty.cast<ToyArrayType>().isGeneric();
|
||||
});
|
||||
});
|
||||
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
|
||||
if (nextop == opWorklist.end())
|
||||
break; // failure: no operations can be inferred.
|
||||
break;
|
||||
|
||||
mlir::Operation *op = *nextop;
|
||||
Operation *op = *nextop;
|
||||
opWorklist.erase(op);
|
||||
|
||||
// Ask the operation to infer its output shapes.
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
|
||||
|
||||
// The add operation is trivial: propagate the input type as is.
|
||||
if (auto addOp = llvm::dyn_cast<AddOp>(op)) {
|
||||
op->getResult(0)->setType(op->getOperand(0)->getType());
|
||||
continue;
|
||||
}
|
||||
|
||||
// Transpose is easy: just invert the dimensions.
|
||||
if (op->getName().getStringRef() == "toy.transpose") {
|
||||
SmallVector<int64_t, 2> dims;
|
||||
auto arrayTy = op->getOperand(0)->getType().cast<ToyArrayType>();
|
||||
dims.insert(dims.end(), arrayTy.getShape().begin(),
|
||||
arrayTy.getShape().end());
|
||||
if (dims.size() == 2)
|
||||
std::swap(dims[0], dims[1]);
|
||||
op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Multiplication is a bit trickier, handle rank 1 as dot product and rank
|
||||
// 2 as matrix multiplications.
|
||||
// We need to be careful about rank mismatch here: the verifier could
|
||||
// catch it but shape inference earlier in the pass could generate an
|
||||
// invalid IR (from an invalid Toy input of course) and we wouldn't want
|
||||
// to crash here.
|
||||
if (auto mulOp = llvm::dyn_cast<MulOp>(op)) {
|
||||
auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
|
||||
auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
|
||||
auto lhsRank = lhs.getShape().size();
|
||||
auto rhsRank = rhs.getShape().size();
|
||||
if (lhsRank != rhsRank) {
|
||||
return op->emitError("shape mismatch: LHS and RHS must have the same "
|
||||
"rank for multiplication, got ")
|
||||
<< lhsRank << " vs " << lhsRank;
|
||||
}
|
||||
SmallVector<int64_t, 2> dims;
|
||||
if (lhsRank == 1) {
|
||||
// dot product, result shape is <1>
|
||||
dims.push_back(1);
|
||||
} else {
|
||||
if (lhsRank != 2) {
|
||||
return op->emitError("shape mismatch: expect rank 1 or 2 for mul "
|
||||
"operands, got ")
|
||||
<< lhsRank;
|
||||
}
|
||||
dims.push_back(lhs.getShape()[0]);
|
||||
dims.push_back(rhs.getShape()[1]);
|
||||
}
|
||||
op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Process calls: lookup the callee after mangling the name with the
|
||||
// argument shapes. If the callee does not exist, we stop the inference
|
||||
// for this function, queue the callee in the inter-procedural work list,
|
||||
// and return. The current function stays in the work list and will
|
||||
// restart after the callee is processed.
|
||||
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
|
||||
auto calleeName = callOp.getCalleeName();
|
||||
auto callee = moduleManager.lookupSymbol<mlir::FuncOp>(calleeName);
|
||||
if (!callee) {
|
||||
signalPassFailure();
|
||||
return f.emitError("shape inference failed, call to unknown '")
|
||||
<< calleeName << "'";
|
||||
}
|
||||
auto mangledName = mangle(calleeName, op->getOpOperands());
|
||||
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
|
||||
<< "', mangled: '" << mangledName << "'\n");
|
||||
auto mangledCallee =
|
||||
moduleManager.lookupSymbol<mlir::FuncOp>(mangledName);
|
||||
if (!mangledCallee) {
|
||||
// Can't find the target, this is where we queue the request for the
|
||||
// callee and stop the inference for the current function now.
|
||||
funcWorklist.push_back({callee, std::move(mangledName),
|
||||
llvm::to_vector<4>(op->getOperandTypes())});
|
||||
return mlir::success();
|
||||
}
|
||||
// Found a specialized callee! Let's turn this into a normal call
|
||||
// operation.
|
||||
SmallVector<mlir::Value *, 8> operands(op->getOperands());
|
||||
mlir::OpBuilder builder(op);
|
||||
auto newCall =
|
||||
builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
|
||||
if (newCall.getNumResults()) {
|
||||
op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
|
||||
op->erase();
|
||||
continue;
|
||||
}
|
||||
if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
|
||||
shapeOp.inferShapes();
|
||||
} else {
|
||||
op->emitError("unable to infer shape of operation without shape "
|
||||
"inference interface");
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
// Done with inference on this function, removing it from the worklist.
|
||||
funcWorklist.pop_back();
|
||||
// Mark the function as non-generic now that inference has succeeded
|
||||
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
|
||||
// If the operation worklist isn't empty, this indicates a failure.
|
||||
if (!opWorklist.empty()) {
|
||||
f.emitError("Shape inference failed, ")
|
||||
<< opWorklist.size() << " operations couldn't be inferred\n";
|
||||
signalPassFailure();
|
||||
auto diag = f.emitError("shape inference failed, ")
|
||||
<< opWorklist.size() << " operations couldn't be inferred\n";
|
||||
for (auto *ope : opWorklist)
|
||||
diag << " - " << *ope << "\n";
|
||||
return diag;
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, update the return type of the function based on the argument to
|
||||
// the return operation.
|
||||
for (auto &block : f.getBlocks()) {
|
||||
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
|
||||
if (!ret)
|
||||
continue;
|
||||
if (ret.getNumOperands() &&
|
||||
f.getType().getResult(0) == ret.getOperand()->getType())
|
||||
// type match, we're done
|
||||
break;
|
||||
SmallVector<mlir::Type, 1> retTy;
|
||||
if (ret.getNumOperands())
|
||||
retTy.push_back(ret.getOperand()->getType());
|
||||
std::vector<mlir::Type> argumentsType;
|
||||
for (auto arg : f.getArguments())
|
||||
argumentsType.push_back(arg->getType());
|
||||
auto newType =
|
||||
mlir::FunctionType::get(argumentsType, retTy, &getContext());
|
||||
f.setType(newType);
|
||||
assert(succeeded(verify(f)));
|
||||
break;
|
||||
}
|
||||
return mlir::success();
|
||||
/// A utility method that returns if the given operation has a dynamically
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||
return !resultType.isa<RankedTensorType>();
|
||||
});
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace toy {
|
||||
std::unique_ptr<mlir::Pass> createShapeInferencePass() {
|
||||
/// Create a Shape Inference pass.
|
||||
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
|
||||
return std::make_unique<ShapeInferencePass>();
|
||||
}
|
||||
} // namespace toy
|
||||
|
|
|
@ -15,24 +15,30 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a simple combiner for optimizing pattern in the Toy
|
||||
// dialect.
|
||||
// This file implements a set of simple combiners for optimizing operations in
|
||||
// the Toy dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
#include <numeric>
|
||||
|
||||
namespace toy {
|
||||
using namespace mlir;
|
||||
using namespace toy;
|
||||
|
||||
namespace {
|
||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||
#include "ToyCombine.inc"
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Fold transpose(transpose(x) -> transpose(x)
|
||||
/// Fold simple cast operations that return the same type as the input.
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return mlir::impl::foldCastOp(*this);
|
||||
}
|
||||
|
||||
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
|
||||
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
|
||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||
/// We register this pattern to match every toy.transpose in the IR.
|
||||
/// The "benefit" is used by the framework to order the patterns and process
|
||||
|
@ -40,9 +46,9 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
SimplifyRedundantTranspose(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
|
||||
|
||||
/// This method is attempting to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. It is expected
|
||||
/// to interact with it to perform any changes to the IR from here.
|
||||
/// This method attempts to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. The pattern is
|
||||
/// expected to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
@ -55,132 +61,23 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
|
||||
// Use the rewriter to perform the replacement
|
||||
// Use the rewriter to perform the replacement.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
|
||||
struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
|
||||
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(ReshapeOp reshape,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current reshape.
|
||||
ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
|
||||
reshape.getOperand()->getDefiningOp());
|
||||
|
||||
// If the input is defined by another constant, bingo!
|
||||
if (!constantOp)
|
||||
return matchFailure();
|
||||
|
||||
auto reshapeType = reshape.getType().cast<ToyArrayType>();
|
||||
if (auto valueAttr =
|
||||
constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
|
||||
// FIXME Check matching of element count!
|
||||
// auto oldType = constantOp.getType();
|
||||
auto newType = rewriter.getTensorType(
|
||||
reshapeType.getShape(), valueAttr.getType().getElementType());
|
||||
auto newAttr = valueAttr.reshape(newType);
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||
newAttr);
|
||||
} else if (auto valueAttr =
|
||||
constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
|
||||
// Broadcast
|
||||
auto dataSize = std::accumulate(reshapeType.getShape().begin(),
|
||||
reshapeType.getShape().end(), 1,
|
||||
std::multiplies<int>());
|
||||
std::vector<mlir::Attribute> data(dataSize, valueAttr);
|
||||
auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
|
||||
reshapeType.getElementType());
|
||||
auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||
newAttr);
|
||||
} else {
|
||||
llvm_unreachable("Unsupported Constant format");
|
||||
}
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold reshape(reshape(x)) -> reshape(x)
|
||||
struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
|
||||
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(ReshapeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current reshape.
|
||||
mlir::Value *reshapeInput = op.getOperand();
|
||||
|
||||
// If the input is defined by another reshape, bingo!
|
||||
if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
|
||||
return matchFailure();
|
||||
|
||||
// Use the rewriter to perform the replacement
|
||||
rewriter.replaceOp(op, {reshapeInput});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold reshape(x)) -> x, when input type matches output type
|
||||
struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
|
||||
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(ReshapeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (op.getOperand()->getType() != op.getType())
|
||||
return matchFailure();
|
||||
rewriter.replaceOp(op, {op.getOperand()});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace.
|
||||
|
||||
// Register our patterns for rewrite by the Canonicalization framework.
|
||||
void TransposeOp::getCanonicalizationPatterns(
|
||||
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
|
||||
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
|
||||
/// that they can be picked up by the Canonicalization framework.
|
||||
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<SimplifyRedundantTranspose>(context);
|
||||
}
|
||||
|
||||
// Register our patterns for rewrite by the Canonicalization framework.
|
||||
void ReshapeOp::getCanonicalizationPatterns(
|
||||
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
|
||||
results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
|
||||
SimplifyNullReshape>(context);
|
||||
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
|
||||
/// that they can be picked up by the Canonicalization framework.
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
|
||||
FoldConstantReshapeOptPattern>(context);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Fold type.cast(x) -> x, when input type matches output type
|
||||
struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern<TypeCastOp> {
|
||||
using mlir::OpRewritePattern<TypeCastOp>::OpRewritePattern;
|
||||
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(TypeCastOp typeCast,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resTy = typeCast.getType();
|
||||
auto *candidateOp = typeCast.getOperation();
|
||||
while (llvm::isa_and_nonnull<TypeCastOp>(candidateOp)) {
|
||||
if (resTy == candidateOp->getOperand(0)->getType()) {
|
||||
rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)});
|
||||
return matchSuccess();
|
||||
}
|
||||
candidateOp = candidateOp->getOperand(0)->getDefiningOp();
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace.
|
||||
|
||||
void TypeCastOp::getCanonicalizationPatterns(
|
||||
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
|
||||
results.insert<SimplifyIdentityTypeCast>(context);
|
||||
}
|
||||
|
||||
} // namespace toy
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines language-specific pattern match optimizations for Toy using
|
||||
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TOY_COMBINE
|
||||
#define TOY_COMBINE
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "toy/Ops.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
/// Note: The DRR definition used for defining patterns is shown below:
|
||||
///
|
||||
/// class Pattern<
|
||||
/// dag sourcePattern, list<dag> resultPatterns,
|
||||
/// list<dag> additionalConstraints = [],
|
||||
/// dag benefitsAdded = (addBenefit 0)
|
||||
/// >;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Basic Pattern-Match and Rewrite
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Reshape(Reshape(x)) = Reshape(x)
|
||||
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
|
||||
(ReshapeOp $arg)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern-Match and Rewrite using Native Code Call
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Native Code Calls may be used for more complex transformations using inline
|
||||
// C++ and C++ helper functions.
|
||||
|
||||
// Reshape(Constant(x)) = x'
|
||||
def ReshapeConstant :
|
||||
NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
|
||||
def FoldConstantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res (ConstantOp $arg)),
|
||||
(ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern-Match and Rewrite with Constraints
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// DRR allows for constraint checking when the transformation is conditional
|
||||
// on operand properties.
|
||||
|
||||
// Reshape(x) = x, where input and output shapes are identical
|
||||
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
|
||||
def RedundantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreIdentical $res, $arg)]>;
|
||||
|
||||
#endif // TOY_COMBINE
|
|
@ -1,403 +0,0 @@
|
|||
//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the dialect for the Toy IR: custom type parsing and
|
||||
// operation verification.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::raw_string_ostream;
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
|
||||
namespace toy {
|
||||
namespace detail {
|
||||
|
||||
/// This class holds the implementation of the ToyArrayType.
|
||||
/// It is intended to be uniqued based on its content and owned by the context.
|
||||
struct ToyArrayTypeStorage : public mlir::TypeStorage {
|
||||
/// This defines how we unique this type in the context: our key contains
|
||||
/// only the shape, a more complex type would have multiple entries in the
|
||||
/// tuple here.
|
||||
/// The element of the tuples usually matches 1-1 the arguments from the
|
||||
/// public `get()` method arguments from the facade.
|
||||
using KeyTy = std::tuple<ArrayRef<int64_t>>;
|
||||
static unsigned hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(std::get<0>(key));
|
||||
}
|
||||
/// When the key hash hits an existing type, we compare the shape themselves
|
||||
/// to confirm we have the right type.
|
||||
bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
|
||||
|
||||
/// This is a factory method to create our type storage. It is only
|
||||
/// invoked after looking up the type in the context using the key and not
|
||||
/// finding it.
|
||||
static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
// Copy the shape array into the bumpptr allocator owned by the context.
|
||||
ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
|
||||
|
||||
// Allocate the instance for the ToyArrayTypeStorage itself
|
||||
auto *storage = allocator.allocate<ToyArrayTypeStorage>();
|
||||
// Initialize the instance using placement new.
|
||||
return new (storage) ToyArrayTypeStorage(shape);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getShape() const { return shape; }
|
||||
|
||||
private:
|
||||
ArrayRef<int64_t> shape;
|
||||
|
||||
/// Constructor is only invoked from the `construct()` method above.
|
||||
ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
mlir::Type ToyArrayType::getElementType() {
|
||||
return mlir::FloatType::getF64(getContext());
|
||||
}
|
||||
|
||||
ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
|
||||
ArrayRef<int64_t> shape) {
|
||||
return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
|
||||
|
||||
mlir::MemRefType ToyArrayType::toMemref() {
|
||||
auto memRefType = mlir::MemRefType::get(getShape(), getElementType(), {}, 0);
|
||||
return memRefType;
|
||||
}
|
||||
|
||||
/// Dialect creation, the instance will be owned by the context. This is the
|
||||
/// point of registration of custom types and operations for the dialect.
|
||||
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
|
||||
addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
|
||||
MulOp, AddOp, ReturnOp, AllocOp, TypeCastOp>();
|
||||
addTypes<ToyArrayType>();
|
||||
}
|
||||
|
||||
/// Parse a type registered to this dialect, we expect only Toy arrays.
/// Accepted forms are `array` (the generic, unranked array) and
/// `array<d0, d1, ...>` with a comma-separated list of dimensions.
/// Returns a null type after emitting an error on failure.
mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
  // Sanity check: we only support array or array<...>
  if (!tyData.startswith("array")) {
    emitError(loc, "invalid Toy type '" + tyData + "', array expected");
    return nullptr;
  }
  // Drop the "array" prefix from the type name, we expect either an empty
  // string or just the shape.
  tyData = tyData.drop_front(StringRef("array").size());
  // This is the generic array case without shape, early return it.
  if (tyData.empty())
    return ToyArrayType::get(getContext());

  // Use a regex to parse the shape (for efficiency we should store this regex
  // in the dialect itself).
  SmallVector<StringRef, 4> matches;
  auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
  if (!shapeRegex.match(tyData, &matches)) {
    emitError(loc, "invalid toy array shape '" + tyData + "'");
    return nullptr;
  }
  SmallVector<int64_t, 4> shape;
  // Iterate through the captures, skip the first one which is the full string.
  for (auto dimStr :
       llvm::make_range(std::next(matches.begin()), matches.end())) {
    if (dimStr.startswith(","))
      continue; // POSIX misses non-capturing groups.
    if (dimStr.empty())
      continue; // '*' makes it an optional group capture
    // Convert the capture to an integer
    unsigned long long dim;
    if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
      emitError(loc, "couldn't parse dimension as integer, matched: " + dimStr);
      return mlir::Type();
    }
    shape.push_back(dim);
  }
  // Finally we collected all the dimensions in the shape,
  // create the array type.
  return ToyArrayType::get(getContext(), shape);
}
|
||||
|
||||
/// Print a Toy array type, for example `array<2, 3, 4>`
|
||||
void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
|
||||
auto arrayTy = type.dyn_cast<ToyArrayType>();
|
||||
if (!arrayTy) {
|
||||
os << "unknown toy type";
|
||||
return;
|
||||
}
|
||||
os << "array";
|
||||
if (!arrayTy.getShape().empty()) {
|
||||
os << "<";
|
||||
mlir::interleaveComma(arrayTy.getShape(), os);
|
||||
os << ">";
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////// Custom Operations for the Dialect /////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to verify that the result of an operation is a Toy array type.
/// Emits an error on the operation and returns failure otherwise.
template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
  if (!op->getResult()->getType().template isa<ToyArrayType>()) {
    std::string msg;
    raw_string_ostream os(msg);
    os << "expects a Toy Array for its argument, got "
       << op->getResult()->getType();
    return op->emitOpError(os.str());
  }
  return mlir::success();
}
|
||||
|
||||
/// Helper to verify that the two operands of a binary operation are Toy
|
||||
/// arrays..
|
||||
template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
|
||||
if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
os << "expects a Toy Array for its LHS, got "
|
||||
<< op->getOperand(0)->getType();
|
||||
return op->emitOpError(os.str());
|
||||
}
|
||||
if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
os << "expects a Toy Array for its LHS, got "
|
||||
<< op->getOperand(0)->getType();
|
||||
return op->emitOpError(os.str());
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
                       ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
  // The result type is a Toy array with the requested shape; the actual data
  // is attached to the operation as the `value` attribute.
  state.types.push_back(ToyArrayType::get(builder->getContext(), shape));
  auto dataAttribute = builder->getNamedAttr("value", value);
  state.attributes.push_back(dataAttribute);
}
|
||||
|
||||
/// Build a constant operation from a single scalar float.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::FloatAttr value) {
  // Broadcast and forward to the other build factory: the scalar is stored as
  // a one-element f64 dense attribute with shape {1}.
  mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
  auto dataType = builder->getTensorType({1}, elementType);
  auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
                           .cast<mlir::DenseElementsAttr>();

  ConstantOp::build(builder, state, {1}, dataAttribute);
}
|
||||
|
||||
/// Verifier for constant operation.
|
||||
mlir::LogicalResult ConstantOp::verify() {
|
||||
// Ensure that the return type is a Toy array
|
||||
if (failed(verifyToyReturnArray(this)))
|
||||
return mlir::failure();
|
||||
|
||||
// We expect the constant itself to be stored as an attribute.
|
||||
auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
|
||||
if (!dataAttr) {
|
||||
return emitOpError(
|
||||
"missing valid `value` DenseElementsAttribute on toy.constant()");
|
||||
}
|
||||
auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
|
||||
if (!attrType) {
|
||||
return emitOpError(
|
||||
"missing valid `value` DenseElementsAttribute on toy.constant()");
|
||||
}
|
||||
|
||||
// If the return type of the constant is not a generic array, the shape must
|
||||
// match the shape of the attribute holding the data.
|
||||
auto resultType = getResult()->getType().cast<ToyArrayType>();
|
||||
if (!resultType.isGeneric()) {
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return emitOpError("The rank of the toy.constant return type must match "
|
||||
"the one of the attached value attribute: " +
|
||||
Twine(attrType.getRank()) +
|
||||
" != " + Twine(resultType.getRank()));
|
||||
}
|
||||
for (int dim = 0; dim < attrType.getRank(); ++dim) {
|
||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
return emitOpError(
|
||||
"Shape mismatch between toy.constant return type and its "
|
||||
"attribute at dimension " +
|
||||
Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
|
||||
" != " + Twine(resultType.getShape()[dim]));
|
||||
}
|
||||
}
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Build a generic call to `callee` with the given argument values.
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
                          StringRef callee, ArrayRef<mlir::Value *> arguments) {
  // Generic call always returns a generic ToyArray initially
  state.types.push_back(ToyArrayType::get(builder->getContext()));
  state.operands.assign(arguments.begin(), arguments.end());
  // The callee is recorded as a string attribute rather than an operand.
  auto calleeAttr = builder->getStringAttr(callee);
  state.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
}
|
||||
|
||||
/// Verifier for the generic call operation: every operand must be a Toy
/// array.
mlir::LogicalResult GenericCallOp::verify() {
  // Verify that every operand is a Toy Array
  for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
    if (!getOperand(opId)->getType().template isa<ToyArrayType>()) {
      std::string msg;
      raw_string_ostream os(msg);
      os << "expects a Toy Array for its " << opId << " operand, got "
         << getOperand(opId)->getType();
      return emitOpError(os.str());
    }
  }
  return mlir::success();
}
|
||||
|
||||
/// Return the name of the callee.
|
||||
StringRef GenericCallOp::getCalleeName() {
|
||||
return getAttr("callee").cast<mlir::StringAttr>().getValue();
|
||||
}
|
||||
|
||||
/// Helper to verify that the single operand of an operation is a Toy array.
/// Emits an error on the operation and returns failure otherwise.
template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
  if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
    std::string msg;
    raw_string_ostream os(msg);
    os << "expects a Toy Array for its argument, got "
       << op->getOperand()->getType();
    return op->emitOpError(os.str());
  }
  return mlir::success();
}
|
||||
|
||||
/// Build a return operation; `value` may be null for a value-less return.
void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state,
                     mlir::Value *value) {
  // Return does not return any value and has an optional single argument
  if (value)
    state.operands.push_back(value);
}
|
||||
|
||||
/// Verifier for the return operation: at most one operand is allowed, and if
/// an operand is present it must be a Toy array.
mlir::LogicalResult ReturnOp::verify() {
  if (getNumOperands() > 1)
    return emitOpError("expects zero or one operand, got " +
                       Twine(getNumOperands()));
  if (hasOperand() && failed(verifyToySingleOperand(this)))
    return mlir::failure();
  return mlir::success();
}
|
||||
|
||||
/// Build a print operation.
void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state,
                    mlir::Value *value) {
  // Print does not return any value and has a single argument
  state.operands.push_back(value);
}
|
||||
|
||||
/// Verifier for the print operation: its single operand must be a Toy array.
mlir::LogicalResult PrintOp::verify() {
  // Directly forward the helper's result; it already encodes success/failure.
  return verifyToySingleOperand(this);
}
|
||||
|
||||
/// Build a transpose operation; the result starts as a generic Toy array.
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
                        mlir::Value *value) {
  state.types.push_back(ToyArrayType::get(builder->getContext()));
  state.operands.push_back(value);
}
|
||||
|
||||
/// Verifier for the transpose operation: its operand must be a Toy array.
mlir::LogicalResult TransposeOp::verify() {
  // Directly forward the helper's result; it already encodes success/failure.
  return verifyToySingleOperand(this);
}
|
||||
|
||||
/// Build a reshape operation producing the given (shaped) result type.
void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state,
                      mlir::Value *value, ToyArrayType reshapedType) {
  state.types.push_back(reshapedType);
  state.operands.push_back(value);
}
|
||||
|
||||
/// Verifier for the reshape operation: the operand must be a Toy array and
/// the result must be a *shaped* (non-generic) Toy array.
mlir::LogicalResult ReshapeOp::verify() {
  if (failed(verifyToySingleOperand(this)))
    return mlir::failure();
  auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
  if (!retTy)
    return emitOpError("toy.reshape is expected to produce a Toy array");
  if (retTy.isGeneric())
    return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
                       "got a generic one.");
  return mlir::success();
}
|
||||
|
||||
/// Build an add operation; the result starts as a generic Toy array.
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
                  mlir::Value *lhs, mlir::Value *rhs) {
  state.types.push_back(ToyArrayType::get(builder->getContext()));
  state.operands.push_back(lhs);
  state.operands.push_back(rhs);
}
|
||||
|
||||
/// Verifier for the add operation: both operands must be Toy arrays.
mlir::LogicalResult AddOp::verify() {
  // Directly forward the helper's result; it already encodes success/failure.
  return verifyToyBinOperands(this);
}
|
||||
|
||||
/// Build a mul operation; the result starts as a generic Toy array.
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
                  mlir::Value *lhs, mlir::Value *rhs) {
  state.types.push_back(ToyArrayType::get(builder->getContext()));
  state.operands.push_back(lhs);
  state.operands.push_back(rhs);
}
|
||||
|
||||
/// Verifier for the mul operation: both operands must be Toy arrays.
mlir::LogicalResult MulOp::verify() {
  // Directly forward the helper's result; it already encodes success/failure.
  return verifyToyBinOperands(this);
}
|
||||
|
||||
/// Build an alloc operation producing a value of the given return type.
void AllocOp::build(mlir::Builder *builder, mlir::OperationState &state,
                    mlir::Type retType) {
  state.types.push_back(retType);
}
|
||||
|
||||
/// Build a type cast of `value` to the destination type `destTy`.
void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::Value *value, mlir::Type destTy) {
  state.operands.push_back(value);
  state.types.push_back(destTy);
}
|
||||
|
||||
} // namespace toy
|
|
@ -20,30 +20,23 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
#include "toy/Lowering.h"
|
||||
#include "toy/MLIRGen.h"
|
||||
#include "toy/Parser.h"
|
||||
#include "toy/Passes.h"
|
||||
|
||||
#include "linalg1/Dialect.h"
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Target/LLVMIR.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/ErrorOr.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace toy;
|
||||
|
@ -64,28 +57,14 @@ static cl::opt<enum InputType> inputType(
|
|||
"load the input file as an MLIR file")));
|
||||
|
||||
namespace {
|
||||
enum Action {
|
||||
None,
|
||||
DumpAST,
|
||||
DumpMLIR,
|
||||
DumpMLIRLinalg,
|
||||
DumpLLVMDialect,
|
||||
DumpLLVMIR,
|
||||
RunJIT
|
||||
};
|
||||
enum Action { None, DumpAST, DumpMLIR, DumpMLIRAffine };
|
||||
}
|
||||
static cl::opt<enum Action> emitAction(
|
||||
"emit", cl::desc("Select the kind of output desired"),
|
||||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||
cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")),
|
||||
cl::values(clEnumValN(DumpMLIRLinalg, "mlir-linalg",
|
||||
"output the MLIR dump after linalg lowering")),
|
||||
cl::values(clEnumValN(DumpLLVMDialect, "llvm-dialect",
|
||||
"output the LLVM MLIR Dialect dump")),
|
||||
cl::values(clEnumValN(DumpLLVMIR, "llvm-ir", "output the LLVM IR dump")),
|
||||
cl::values(
|
||||
clEnumValN(RunJIT, "jit",
|
||||
"JIT the code and run it by invoking the main function")));
|
||||
cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine",
|
||||
"output the MLIR dump after affine lowering")));
|
||||
|
||||
static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
|
||||
|
||||
|
@ -103,174 +82,81 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
|||
return parser.ParseModule();
|
||||
}
|
||||
|
||||
mlir::LogicalResult optimize(mlir::ModuleOp module) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
|
||||
// Apply any generic pass manager command line options.
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
return pm.run(module);
|
||||
}
|
||||
|
||||
mlir::LogicalResult lowerDialect(mlir::ModuleOp module, bool OnlyLinalg) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(createEarlyLoweringPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
if (!OnlyLinalg) {
|
||||
pm.addPass(createLateLoweringPass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
}
|
||||
// Apply any generic pass manager command line options.
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
return pm.run(module);
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef loadFileAndProcessModule(
|
||||
mlir::MLIRContext &context, bool EnableLinalgLowering = false,
|
||||
bool EnableLLVMLowering = false, bool EnableOpt = false) {
|
||||
|
||||
mlir::OwningModuleRef module;
|
||||
if (inputType == InputType::MLIR ||
|
||||
llvm::StringRef(inputFilename).endswith(".mlir")) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||
if (std::error_code EC = fileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||
module = mlir::parseSourceFile(sourceMgr, &context);
|
||||
if (!module) {
|
||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
if (failed(mlir::verify(*module))) {
|
||||
llvm::errs() << "Error verifying MLIR module\n";
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||
// Handle '.toy' input to the compiler.
|
||||
if (inputType != InputType::MLIR &&
|
||||
!llvm::StringRef(inputFilename).endswith(".mlir")) {
|
||||
auto moduleAST = parseInputFile(inputFilename);
|
||||
module = mlirGen(context, *moduleAST);
|
||||
return !module ? 1 : 0;
|
||||
}
|
||||
if (!module)
|
||||
return nullptr;
|
||||
if (EnableOpt) {
|
||||
if (failed(optimize(*module))) {
|
||||
llvm::errs() << "Module optimization failed\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Otherwise, the input is '.mlir'.
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||
if (std::error_code EC = fileOrErr.getError()) {
|
||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
||||
return -1;
|
||||
}
|
||||
if (EnableLLVMLowering || EnableLinalgLowering) {
|
||||
if (failed(lowerDialect(*module, !EnableLLVMLowering))) {
|
||||
llvm::errs() << "Module lowering failed\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Parse the input mlir.
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
||||
module = mlir::parseSourceFile(sourceMgr, &context);
|
||||
if (!module) {
|
||||
llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
||||
return 3;
|
||||
}
|
||||
return module;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int dumpMLIR() {
|
||||
// Register our Dialect with MLIR.
|
||||
mlir::registerDialect<mlir::toy::ToyDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
auto module =
|
||||
loadFileAndProcessModule(context, /*EnableLinalgLowering=*/false,
|
||||
/*EnableLLVMLowering=*/false, EnableOpt);
|
||||
if (!module)
|
||||
return -1;
|
||||
mlir::OwningModuleRef module;
|
||||
if (int error = loadMLIR(context, module))
|
||||
return error;
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
// Apply any generic pass manager command line options and run the pipeline.
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
// Check to see what granularity of MLIR we are compiling to.
|
||||
bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine;
|
||||
|
||||
if (EnableOpt || isLoweringToAffine) {
|
||||
// Inline all functions into main and then delete them.
|
||||
pm.addPass(mlir::createInlinerPass());
|
||||
pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
|
||||
|
||||
// Now that there is only one function, we can infer the shapes of each of
|
||||
// the operations.
|
||||
pm.addPass(mlir::toy::createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
}
|
||||
|
||||
if (isLoweringToAffine) {
|
||||
// Partially lower the toy dialect with a few cleanups afterwards.
|
||||
pm.addPass(mlir::toy::createLowerToAffinePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
pm.addPass(mlir::createCSEPass());
|
||||
|
||||
// Add optimizations if enabled.
|
||||
if (EnableOpt) {
|
||||
pm.addPass(mlir::createLoopFusionPass());
|
||||
pm.addPass(mlir::createMemRefDataFlowOptPass());
|
||||
}
|
||||
}
|
||||
|
||||
if (mlir::failed(pm.run(*module)))
|
||||
return 4;
|
||||
|
||||
module->dump();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int dumpMLIRLinalg() {
|
||||
mlir::MLIRContext context;
|
||||
auto module = loadFileAndProcessModule(context, /*EnableLinalgLowering=*/true,
|
||||
/*EnableLLVMLowering=*/false,
|
||||
/* EnableOpt=*/true);
|
||||
if (!module)
|
||||
return -1;
|
||||
module->dump();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int dumpLLVMDialect() {
|
||||
mlir::MLIRContext context;
|
||||
auto module = loadFileAndProcessModule(
|
||||
context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true,
|
||||
/* EnableOpt=*/true);
|
||||
if (!module) {
|
||||
llvm::errs() << "Failed to load/lower MLIR module\n";
|
||||
return -1;
|
||||
}
|
||||
module->dump();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int dumpLLVMIR() {
|
||||
mlir::MLIRContext context;
|
||||
auto module = loadFileAndProcessModule(
|
||||
context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true,
|
||||
/* EnableOpt=*/true);
|
||||
if (!module) {
|
||||
llvm::errs() << "Failed to load/lower MLIR module\n";
|
||||
return -1;
|
||||
}
|
||||
auto llvmModule = translateModuleToLLVMIR(*module);
|
||||
if (!llvmModule) {
|
||||
llvm::errs() << "Failed to emit LLVM IR\n";
|
||||
return -1;
|
||||
}
|
||||
// Initialize LLVM targets.
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
|
||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||
/* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0,
|
||||
/* targetMachine=*/nullptr);
|
||||
if (auto err = optPipeline(llvmModule.get())) {
|
||||
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
|
||||
return -1;
|
||||
}
|
||||
llvm::errs() << *llvmModule << "\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
int runJit() {
|
||||
mlir::MLIRContext context;
|
||||
auto module = loadFileAndProcessModule(
|
||||
context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true,
|
||||
/* EnableOpt=*/true);
|
||||
|
||||
// Initialize LLVM targets.
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
|
||||
// Create an MLIR execution engine. The execution engine eagerly JIT-compiles
|
||||
// the module.
|
||||
auto optPipeline = mlir::makeOptimizingTransformer(
|
||||
/* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0,
|
||||
/* targetMachine=*/nullptr);
|
||||
auto maybeEngine = mlir::ExecutionEngine::create(*module, optPipeline);
|
||||
assert(maybeEngine && "failed to construct an execution engine");
|
||||
auto &engine = maybeEngine.get();
|
||||
|
||||
// Invoke the JIT-compiled function with the arguments. Note that, for API
|
||||
// uniformity reasons, it takes a list of type-erased pointers to arguments.
|
||||
auto invocationResult = engine->invoke("main");
|
||||
if (invocationResult) {
|
||||
llvm::errs() << "JIT invocation failed\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int dumpAST() {
|
||||
if (inputType == InputType::MLIR) {
|
||||
llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
|
||||
|
@ -286,10 +172,6 @@ int dumpAST() {
|
|||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// Register our Dialects with MLIR
|
||||
mlir::registerDialect<ToyDialect>();
|
||||
mlir::registerDialect<linalg::LinalgDialect>();
|
||||
|
||||
mlir::registerPassManagerCLOptions();
|
||||
cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
|
||||
|
||||
|
@ -297,18 +179,10 @@ int main(int argc, char **argv) {
|
|||
case Action::DumpAST:
|
||||
return dumpAST();
|
||||
case Action::DumpMLIR:
|
||||
case Action::DumpMLIRAffine:
|
||||
return dumpMLIR();
|
||||
case Action::DumpMLIRLinalg:
|
||||
return dumpMLIRLinalg();
|
||||
case Action::DumpLLVMDialect:
|
||||
return dumpLLVMDialect();
|
||||
case Action::DumpLLVMIR:
|
||||
return dumpLLVMIR();
|
||||
case Action::RunJIT:
|
||||
return runJit();
|
||||
default:
|
||||
llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Chapter 5: CodeGen via Lowering to Lower-Level Dialects
|
||||
# Chapter 5 - Partial Lowering to Lower-Level Dialects for Optimization
|
||||
|
||||
At this point, we are eager to generate actual code and see our Toy language
|
||||
taking life. We will obviously use LLVM to generate code, but just showing the
|
||||
|
@ -6,293 +6,356 @@ LLVM builder interface wouldn't be very exciting here. Instead, we will show how
|
|||
to perform progressive lowering through a mix of dialects coexisting in the same
|
||||
function.
|
||||
|
||||
To make it more interesting, we will consider that we want to reuse existing
|
||||
optimizations implemented in a dialect optimizing linear algebra: `Linalg`. This
|
||||
dialect is tailored to the computation heavy part of the program, and is
|
||||
limited: it doesn't support representing our `toy.print` builtin for instance,
|
||||
neither should it! Instead we can target `Linalg` for the computation heavy part
|
||||
of Toy (mostly matmul), we will target the `Affine` dialect for other
|
||||
well-formed loop nest, and directly the `LLVM IR` dialect for lowering `print`.
|
||||
To make it more interesting, in this chapter we will consider that we want
|
||||
to reuse existing optimizations implemented in a dialect optimizing affine
|
||||
transformations: `Affine`. This dialect is tailored to the computation-heavy
|
||||
part of the program, and is limited: it doesn't support representing our
|
||||
`toy.print` builtin for instance, nor should it! Instead we can target
|
||||
`Affine` for the computation heavy part of Toy, and in the
|
||||
[next chapter](Ch-6.md) directly the `LLVM IR` dialect for lowering `print`. As
|
||||
part of this lowering, we will be lowering from the
|
||||
[TensorType](../../LangRef.md#tensor-type), that `Toy` operates on, to the
|
||||
[MemRefType](../../LangRef.md#memref-type) that is indexed via an affine
|
||||
loop-nest. Tensors represent an abstract value-typed sequence of data, meaning
|
||||
that they don't live in any memory. MemRefs on the other hand represent lower
|
||||
level buffer access, as they are concrete references to a region of memory.
|
||||
|
||||
# The `DialectConversion` Framework
|
||||
# Dialect Conversions
|
||||
|
||||
Similarly to the canonicalization patterns introduced in the previous section,
|
||||
the `DialectConversion` framework involves its own set of patterns. This
|
||||
framework operates a bit differently from the canonicalizer: a new function is
|
||||
created and the pattern matching operation in the original function are expected
|
||||
to emit the IR in the new function.
|
||||
MLIR contains many different dialects, so it is important to have a unified
|
||||
framework for converting between them. This is where the `DialectConversion`
|
||||
framework comes into play. This framework allows for transforming a set of
|
||||
`illegal` operations to a set of `legal` ones. To use this framework we need to
|
||||
provide two things:
|
||||
|
||||
Dialect conversion requires three components, implemented by overriding virtual
|
||||
methods defined in `DialectConversion`:
|
||||
* A [Conversion Target](../../DialectConversion.md#conversion-target)
|
||||
|
||||
- Type Conversion: for things like block arguments' type.
|
||||
- Function signature conversion: for every function it is invoked with the
|
||||
function type and the conversion generates a new prototype for the converted
|
||||
function. The default implementation will call into the type conversion for
|
||||
the returned values and for each of the parameters.
|
||||
- Operations conversions: each pattern is expected to generate new results
|
||||
matching the current operations' in the new function. This may involve
|
||||
generating one or multiple new operations, or possibly just remapping
|
||||
existing operands (folding).
|
||||
- This is the formal specification of what operations, or dialects, are
|
||||
legal for the conversion. Operations that aren't legal will require
|
||||
rewrite patterns to perform legalization.
|
||||
|
||||
A typical starting point for implementing our lowering would be:
|
||||
* A set of
|
||||
[Rewrite Patterns](../../DialectConversion.md#rewrite-pattern-specification)
|
||||
|
||||
- These are the set of [patterns](../../QuickstartRewrites.md) used to
|
||||
convert `illegal` operations into a set of zero or more `legal` ones.
|
||||
|
||||
* Optionally, A [Type Converter](../../DialectConversion.md#type-conversion).
|
||||
|
||||
- If provided, this is used to convert the types of block arguments. We
|
||||
won't be needing this for our conversion.
|
||||
|
||||
## Conversion Target
|
||||
|
||||
For our purposes, we want to convert the compute intensive `Toy` operations into
|
||||
a combination of operations from the `Affine` `Standard` dialects for further
|
||||
optimization. To start off the lowering, we first define our conversion target:
|
||||
|
||||
```c++
|
||||
class Lowering : public DialectConversion {
|
||||
public:
|
||||
// This gets called for block and region arguments, and attributes.
|
||||
Type convertType(Type t) override { /*...*/ }
|
||||
void ToyToAffineLoweringPass::runOnFunction() {
|
||||
// The first thing to define is the conversion target. This will define the
|
||||
// final target for this lowering.
|
||||
mlir::ConversionTarget target(getContext());
|
||||
|
||||
// This gets called for functions.
|
||||
FunctionType convertFunctionSignatureType(FunctionType type,
|
||||
ArrayRef<NamedAttributeList> argAttrs,
|
||||
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) { /*...*/ }
|
||||
// We define the specific operations, or dialects, that are legal targets for
|
||||
// this lowering. In our case, we are lowering to a combination of the
|
||||
// `Affine` and `Standard` dialects.
|
||||
target.addLegalDialect<mlir::AffineOpsDialect, mlir::StandardOpsDialect>();
|
||||
|
||||
// This gets called once to set up operation converters.
|
||||
llvm::DenseSet<ConversionPattern *>
|
||||
initConverters(MLIRContext *context) override {
|
||||
RewriteListBuilder<MulOpConversion, PrintOpConversion,
|
||||
TransposeOpConversion>::build(allocator, context);
|
||||
// We also define the Toy dialect as Illegal so that the conversion will fail
|
||||
// if any of these operations are *not* converted. Given that we actually want
|
||||
// a partial lowering, we explicitly mark the Toy operations that don't want
|
||||
// to lower, `toy.print`, as `legal`.
|
||||
target.addIllegalDialect<ToyDialect>();
|
||||
target.addLegalOp<PrintOp>();
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
## Conversion Patterns
|
||||
|
||||
After the conversion target has been defined, we can define how to convert the
|
||||
`illegal` operations into `legal` ones. Similarly to the canonicalization
|
||||
framework introduced in [chapter 3](Ch-3.md), the
|
||||
[`DialectConversion` framework](../../DialectConversion.md) also uses
|
||||
[RewritePatterns](../../QuickstartRewrites.md) to perform the conversion logic.
|
||||
These patterns may be the `RewritePatterns` seen before, or a new type of
|
||||
pattern specific to the conversion framework `ConversionPattern`.
|
||||
`ConversionPatterns` are different from traditional `RewritePatterns` in that
|
||||
they accept an additional `operands` parameter containing operands that have
|
||||
been remapped/replaced. This is used when dealing with type conversions as the
|
||||
pattern will want to operate on values of the new type, but match against the
|
||||
old. For our lowering, this invariant will be useful during our lowering as we
|
||||
will be translating from the [TensorType](../../LangRef.md#tensor-type),
|
||||
currently being operated on, to the [MemRefType](../../LangRef.md#memref-type).
|
||||
Let's look at a snippet of lowering the `toy.transpose` operation:
|
||||
|
||||
```c++
|
||||
/// Lower the `toy.transpose` operation to an affine loop nest.
|
||||
struct TransposeOpLowering : public mlir::ConversionPattern {
|
||||
TransposeOpLowering(mlir::MLIRContext *ctx)
|
||||
: mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
/// Match and rewrite the given `toy.transpose` operation, with the given
|
||||
/// operands that have been remapped from `tensor<...>` to `memref<...>`.
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value *> operands,
|
||||
mlir::ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Call to a helper function that will lower the current operation to a set
|
||||
// of affine loops. We provide a functor that operates on the remapped
|
||||
// operands, as well as the loop induction variables for the inner most
|
||||
// loop body.
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](mlir::PatternRewriter &rewriter,
|
||||
ArrayRef<mlir::Value *> memRefOperands,
|
||||
ArrayRef<mlir::Value *> loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the TransposeOp.
|
||||
// This allows for using the nice named accessors that are generated
|
||||
// by the ODS. This adaptor is automatically provided by the ODS
|
||||
// framework.
|
||||
TransposeOpOperandAdaptor tranposeAdaptor(memRefOperands);
|
||||
mlir::Value *input = tranposeAdaptor.input();
|
||||
|
||||
// Transpose the elements by generating a load from the reverse
|
||||
// indices.
|
||||
SmallVector<mlir::Value *, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
|
||||
});
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::BumpPtrAllocator allocator;
|
||||
};
|
||||
```
|
||||
|
||||
Individual operation converters are following this pattern:
|
||||
Now we can prepare the list of patterns to use during the lowering process:
|
||||
|
||||
```c++
|
||||
/// Lower a toy.add to an affine loop nest.
|
||||
///
|
||||
/// This class inherits from `ConversionPattern` and overrides `rewrite`,
|
||||
/// similarly to the PatternRewriter introduced in the previous chapter.
|
||||
/// It will be called by the DialectConversion framework (see `LateLowering`
|
||||
/// class below).
|
||||
class AddOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit AddOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(toy::AddOp::getOperationName(), 1, context) {}
|
||||
void ToyToAffineLoweringPass::runOnFunction() {
|
||||
...
|
||||
|
||||
/// Lower the `op` by generating IR using the `rewriter` builder. The builder
|
||||
/// is setup with a new function, the `operands` array has been populated with
|
||||
/// the rewritten operands for `op` in the new function.
|
||||
/// The results created by the new IR with the builder are returned, and their
|
||||
/// number must match the number of result of `op`.
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
OpBuilder &rewriter) const override {
|
||||
...
|
||||
// Now that the conversion target has been defined, we just need to provide
|
||||
// the set of patterns that will lower the Toy operations.
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
patterns.insert<..., TransposeOpLowering>(&getContext());
|
||||
|
||||
// Return the newly allocated buffer, it will be used as an operand when
|
||||
// converting the operations corresponding to the users of this `toy.add`.
|
||||
return result;
|
||||
}
|
||||
...
|
||||
```
|
||||
|
||||
## Linalg
|
||||
## Partial Lowering
|
||||
|
||||
Linalg is an advanced dialect for dense algebra optimizations. It is implemented
|
||||
as [a separate tutorial](../Linalg/Ch-1.md) in parallel with Toy. We are acting
|
||||
as a user of this dialect by lowering Toy matrix multiplications to
|
||||
`linalg.matmul`.
|
||||
Once the patterns have been defined, we can perform the actual lowering. The
|
||||
`DialectConversion` framework provides several different modes of lowering, but
|
||||
for our purposes we will be performing a partial lowering, as we will not be
|
||||
converting `toy.print` at this time.
|
||||
|
||||
To support this, we will split our lowering in two parts: an *early lowering*
|
||||
that emits operations in the `Linalg` dialect for a subset of the Toy IR, and a
|
||||
*late lowering* that materializes buffers and converts all operations and type
|
||||
to the LLVM dialect. We will then be able to run specific optimizations in
|
||||
between the two lowering.
|
||||
```c++
|
||||
void ToyToAffineLoweringPass::runOnFunction() {
|
||||
// The first thing to define is the conversion target. This will define the
|
||||
// final target for this lowering.
|
||||
mlir::ConversionTarget target(getContext());
|
||||
|
||||
Let's look again at our example `multiply_transpose`:
|
||||
// We define the specific operations, or dialects, that are legal targets for
|
||||
// this lowering. In our case, we are lowering to a combination of the
|
||||
// `Affine` and `Standard` dialects.
|
||||
target.addLegalDialect<mlir::AffineOpsDialect, mlir::StandardOpsDialect>();
|
||||
|
||||
```mlir
|
||||
func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
|
||||
attributes {toy.generic: true} {
|
||||
%0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
|
||||
%1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
|
||||
"toy.return"(%1) : (!toy.array) -> ()
|
||||
// We also define the Toy dialect as Illegal so that the conversion will fail
|
||||
// if any of these operations are *not* converted. Given that we actually want
|
||||
// a partial lowering, we explicitly mark the Toy operations that we don't want
|
||||
// to lower, `toy.print`, as `legal`.
|
||||
target.addIllegalDialect<ToyDialect>();
|
||||
target.addLegalOp<PrintOp>();
|
||||
|
||||
// Now that the conversion target has been defined, we just need to provide
|
||||
// the set of patterns that will lower the Toy operations.
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
patterns.insert<..., TransposeOpLowering>(&getContext());
|
||||
|
||||
// With the target and rewrite patterns defined, we can now attempt the
|
||||
// conversion. The conversion will signal failure if any of our `illegal`
|
||||
// operations were not converted successfully.
|
||||
auto function = getFunction();
|
||||
if (mlir::failed(mlir::applyPartialConversion(function, target, patterns)))
|
||||
signalPassFailure();
|
||||
}
|
||||
```
|
||||
|
||||
After shape inference, and lowering to `Linalg`, here is what our IR will look
|
||||
like:
|
||||
### Design Considerations With Partial Lowering
|
||||
|
||||
Before diving into the result of our lowering, this is a good time to discuss
|
||||
potential design considerations when it comes to partial lowering. In our
|
||||
lowering, we will be transforming from a value-type, TensorType, to an
|
||||
allocated (buffer-like) type, MemRefType. Given that we will not be lowering the
|
||||
`toy.print` operation, we need to temporarily bridge these two worlds. There are
|
||||
many ways to go about this, each with their own tradeoffs:
|
||||
|
||||
* Generate `load` operations from the buffer
|
||||
|
||||
One option is to generate `load` operations from the buffer type to materialize
|
||||
an instance of the value type. This allows for the definition of the `toy.print`
|
||||
operation to remain unchanged. The downside to this approach is that the
|
||||
optimizations on the `affine` dialect are limited, because the `load` will
|
||||
actually involve a full copy that is only visible *after* our optimizations have
|
||||
been performed.
|
||||
|
||||
* Generate a new version of `toy.print` that operates on the lowered type
|
||||
|
||||
Another option would be to have another, lowered, variant of `toy.print` that
|
||||
operates on the lowered type. The benefit of this option is that there is no
|
||||
hidden, unnecessary, copy to optimizer. The downside is that another operation
|
||||
definition is needed, that may duplicate many aspects of the first. Defining a
|
||||
base class in [ODS](../../OpDefinitions.md) may simplify this, but you still
|
||||
need to treat these operations separately.
|
||||
|
||||
* Update `toy.print` to allow for operating on the lowered type
|
||||
|
||||
A third option is to update the current definition of `toy.print` to allow for
|
||||
operating on the lowered type. The benefit of this approach is that it is
|
||||
simple, does not introduce an additional hidden copy, and does not require
|
||||
another operation definition. The downside to this option is that it requires
|
||||
mixing abstraction levels in the `Toy` dialect.
|
||||
|
||||
For the sake of simplicity, we will use the third option for this lowering. This
|
||||
involves updating the type constraints on the PrintOp in the operation
|
||||
definition file:
|
||||
|
||||
```tablegen
|
||||
def PrintOp : Toy_Op<"print"> {
|
||||
...
|
||||
|
||||
// The print operation takes an input tensor to print.
|
||||
// We also allow a F64MemRef to enable interop during partial lowering.
|
||||
let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
|
||||
}
|
||||
```
|
||||
|
||||
## Complete Toy Example
|
||||
|
||||
Looking back at our current working example:
|
||||
|
||||
```.mlir
|
||||
func @main() {
|
||||
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
%2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
|
||||
%3 = "toy.mul"(%2, %2) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
|
||||
"toy.print"(%3) : (tensor<2x3xf64>) -> ()
|
||||
"toy.return"() : () -> ()
|
||||
}
|
||||
```
|
||||
|
||||
With affine lowering added to our pipeline, we can now generate:
|
||||
|
||||
```mlir
|
||||
func @multiply_transpose_2x3_2x3(%arg0: !toy.array<2, 3>, %arg1: !toy.array<2, 3>) -> !toy.array<2, 2>
|
||||
attributes {toy.generic: false} {
|
||||
%c3 = constant 3 : index
|
||||
func @main() {
|
||||
%c0 = constant 0 : index
|
||||
%c2 = constant 2 : index
|
||||
%c1 = constant 1 : index
|
||||
%0 = "toy.transpose"(%arg1) : (!toy.array<2, 3>) -> !toy.array<3, 2>
|
||||
%1 = "toy.alloc"() : () -> !toy.array<2, 2>
|
||||
%2 = "toy.cast"(%1) : (!toy.array<2, 2>) -> memref<2x2xf64>
|
||||
%3 = "toy.cast"(%arg0) : (!toy.array<2, 3>) -> memref<2x3xf64>
|
||||
%4 = "toy.cast"(%0) : (!toy.array<3, 2>) -> memref<3x2xf64>
|
||||
%5 = linalg.range %c0:%c2:%c1 : !linalg.range
|
||||
%6 = linalg.range %c0:%c3:%c1 : !linalg.range
|
||||
%7 = linalg.view %3[%5, %6] : !linalg<"view<?x?xf64>">
|
||||
%8 = linalg.view %4[%6, %5] : !linalg<"view<?x?xf64>">
|
||||
%9 = linalg.view %2[%5, %5] : !linalg<"view<?x?xf64>">
|
||||
linalg.matmul(%7, %8, %9) : !linalg<"view<?x?xf64>">
|
||||
"toy.return"(%1) : (!toy.array<2, 2>) -> ()
|
||||
}
|
||||
```
|
||||
%c2 = constant 2 : index
|
||||
%cst = constant 1.000000e+00 : f64
|
||||
%cst_0 = constant 2.000000e+00 : f64
|
||||
%cst_1 = constant 3.000000e+00 : f64
|
||||
%cst_2 = constant 4.000000e+00 : f64
|
||||
%cst_3 = constant 5.000000e+00 : f64
|
||||
%cst_4 = constant 6.000000e+00 : f64
|
||||
|
||||
Note how the operations from multiple dialects are coexisting in this function.
|
||||
// Allocating buffers for the inputs and outputs.
|
||||
%0 = alloc() : memref<2x3xf64>
|
||||
%1 = alloc() : memref<2x3xf64>
|
||||
%2 = alloc() : memref<2x3xf64>
|
||||
|
||||
You can reproduce this result with `bin/toyc-ch5
|
||||
test/Examples/Toy/Ch5/lowering.toy -emit=mlir-linalg`
|
||||
// Initialize the input buffer with the constant values.
|
||||
affine.store %cst, %2[%c0, %c0] : memref<2x3xf64>
|
||||
affine.store %cst_0, %2[%c0, %c1] : memref<2x3xf64>
|
||||
affine.store %cst_1, %2[%c0, %c2] : memref<2x3xf64>
|
||||
affine.store %cst_2, %2[%c1, %c0] : memref<2x3xf64>
|
||||
affine.store %cst_3, %2[%c1, %c1] : memref<2x3xf64>
|
||||
affine.store %cst_4, %2[%c1, %c2] : memref<2x3xf64>
|
||||
|
||||
## Emitting LLVM
|
||||
|
||||
The availability of various dialects allows for a smooth lowering by reducing
|
||||
the impedance mismatch between dialects. For example we don't need to lower our
|
||||
`toy.print` over array directly to LLVM IR, we can use the well structured loop
|
||||
from the `Affine` dialect for convenience when scanning the array and insert a
|
||||
call to `llvm.printf` in the body. We will rely on MLIR lowering to LLVM for the
|
||||
`Affine` dialect, so we get it for free. Here is a simplified version of the code
|
||||
in this chapter for lowering `toy.print`:
|
||||
|
||||
```c++
|
||||
// Create our loop nest now
|
||||
using namespace edsc;
|
||||
using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
|
||||
ScopedContext scope(rewriter, loc);
|
||||
ValueHandle zero = intrinsics::constant_index(0);
|
||||
ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
|
||||
ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
|
||||
MemRefView vOp(operand);
|
||||
IndexedValue iOp(operand);
|
||||
IndexHandle i, j, M(vOp.ub(0)), N(vOp.ub(1));
|
||||
LoopBuilder(&i, zero, M, 1)({
|
||||
LoopBuilder(&j, zero, N, 1)({
|
||||
llvmCall(retTy,
|
||||
rewriter.getSymbolRefAttr(printfFunc),
|
||||
{fmtCst, iOp(i, j)})
|
||||
}),
|
||||
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol})
|
||||
});
|
||||
```
|
||||
|
||||
For instance the Toy IR may contain:
|
||||
|
||||
```
|
||||
"toy.print"(%0) : (!toy.array<2, 2>) -> ()
|
||||
```
|
||||
|
||||
which the converter above will turn into this sequence:
|
||||
|
||||
```mlir
|
||||
affine.for %i0 = 0 to 2 {
|
||||
affine.for %i1 = 0 to 2 {
|
||||
%3 = load %0[%i0, %i1] : memref<2x2xf64>
|
||||
%4 = llvm.call @printf(%1, %3) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
|
||||
// Load the transpose value from the input buffer and store it into the
|
||||
// next input buffer.
|
||||
affine.for %arg0 = 0 to 2 {
|
||||
affine.for %arg1 = 0 to 3 {
|
||||
%3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64>
|
||||
affine.store %3, %1[%arg0, %arg1] : memref<2x3xf64>
|
||||
}
|
||||
%5 = llvm.call @printf(%2, %cst_21) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
|
||||
}
|
||||
|
||||
// Multiply and store into the output buffer.
|
||||
affine.for %arg0 = 0 to 2 {
|
||||
affine.for %arg1 = 0 to 3 {
|
||||
%3 = affine.load %1[%arg0, %arg1] : memref<2x3xf64>
|
||||
%4 = affine.load %1[%arg0, %arg1] : memref<2x3xf64>
|
||||
%5 = mulf %3, %4 : f64
|
||||
affine.store %5, %0[%arg0, %arg1] : memref<2x3xf64>
|
||||
}
|
||||
}
|
||||
|
||||
// Print the value held by the buffer.
|
||||
"toy.print"(%0) : (memref<2x3xf64>) -> ()
|
||||
dealloc %2 : memref<2x3xf64>
|
||||
dealloc %1 : memref<2x3xf64>
|
||||
dealloc %0 : memref<2x3xf64>
|
||||
return
|
||||
}
|
||||
```
|
||||
|
||||
Note the mix of a loop nest in the `Affine` dialect, with an operation
|
||||
`llvm.call` in the body. MLIR knows already how to lower this to:
|
||||
## Taking Advantage of Affine Optimization
|
||||
|
||||
Our naive lowering is correct, but it leaves a lot to be desired in regards to
|
||||
efficiency. For example, the lowering of `toy.mul` has generated some redundant
|
||||
loads. Let's look at how adding a few existing optimizations to the pipeline can
|
||||
help clean this up. Adding the `LoopFusion` and `MemRefDataFlowOpt` passes to
|
||||
the pipeline gives the following result:
|
||||
|
||||
```mlir
|
||||
llvm.br ^bb1(%87 : !llvm.i64)
|
||||
^bb1(%89: !llvm.i64): // 2 preds: ^bb0, ^bb5
|
||||
%90 = llvm.icmp "slt" %89, %88 : !llvm.i64
|
||||
llvm.cond_br %90, ^bb2, ^bb6
|
||||
^bb2: // pred: ^bb1
|
||||
%91 = llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
%92 = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
llvm.br ^bb3(%91 : !llvm.i64)
|
||||
^bb3(%93: !llvm.i64): // 2 preds: ^bb2, ^bb4
|
||||
%94 = llvm.icmp "slt" %93, %92 : !llvm.i64
|
||||
llvm.cond_br %94, ^bb4, ^bb5
|
||||
^bb4: // pred: ^bb3
|
||||
%95 = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
%96 = llvm.mlir.constant(2 : index) : !llvm.i64
|
||||
%97 = llvm.mul %89, %96 : !llvm.i64
|
||||
%98 = llvm.add %97, %93 : !llvm.i64
|
||||
%99 = llvm.getelementptr %6[%98] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
|
||||
%100 = llvm.load %99 : !llvm<"double*">
|
||||
%101 = llvm.call @printf(%48, %100) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
|
||||
%102 = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
%103 = llvm.add %93, %102 : !llvm.i64
|
||||
llvm.br ^bb3(%103 : !llvm.i64)
|
||||
^bb5: // pred: ^bb3
|
||||
%104 = llvm.call @printf(%76, %71) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
|
||||
%105 = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
%106 = llvm.add %89, %105 : !llvm.i64
|
||||
llvm.br ^bb1(%106 : !llvm.i64)
|
||||
```
|
||||
func @main() {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%cst = constant 1.000000e+00 : f64
|
||||
%cst_0 = constant 2.000000e+00 : f64
|
||||
%cst_1 = constant 3.000000e+00 : f64
|
||||
%cst_2 = constant 4.000000e+00 : f64
|
||||
%cst_3 = constant 5.000000e+00 : f64
|
||||
%cst_4 = constant 6.000000e+00 : f64
|
||||
|
||||
We appreciate the ease to generate the former, as well as the readability!
|
||||
// Allocating buffers for the inputs and outputs.
|
||||
%0 = alloc() : memref<2x3xf64>
|
||||
%1 = alloc() : memref<2x3xf64>
|
||||
|
||||
You may reproduce these results with `echo "def main() { print([[1,2],[3,4]]); }
|
||||
" | bin/toyc-ch5 -x toy - -emit=llvm-dialect` and `echo "def main() {
|
||||
print([[1,2],[3,4]]); } " | bin/toyc-ch5 -x toy - -emit=llvm-ir`.
|
||||
// Initialize the input buffer with the constant values.
|
||||
affine.store %cst, %1[%c0, %c0] : memref<2x3xf64>
|
||||
affine.store %cst_0, %1[%c0, %c1] : memref<2x3xf64>
|
||||
affine.store %cst_1, %1[%c0, %c2] : memref<2x3xf64>
|
||||
affine.store %cst_2, %1[%c1, %c0] : memref<2x3xf64>
|
||||
affine.store %cst_3, %1[%c1, %c1] : memref<2x3xf64>
|
||||
affine.store %cst_4, %1[%c1, %c2] : memref<2x3xf64>
|
||||
|
||||
# CodeGen: Getting Out of MLIR
|
||||
affine.for %arg0 = 0 to 2 {
|
||||
affine.for %arg1 = 0 to 3 {
|
||||
// Load the transpose value from the input buffer.
|
||||
%2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64>
|
||||
|
||||
At this point, all the IR is expressed in the LLVM dialect, MLIR can perform a
|
||||
straight conversion to an LLVM module. You may look into
|
||||
[`Ch5/toyc.cpp`](../../../examples/toy/Ch5/toyc.cpp) for the `dumpLLVM()`
|
||||
function:
|
||||
|
||||
```c++
|
||||
int dumpLLVM() {
|
||||
mlir::MLIRContext context;
|
||||
auto module = loadFileAndProcessModule(context, /* EnableLowering=*/ true);
|
||||
auto llvmModule = translateModuleToLLVMIR(*module);
|
||||
if (!llvmModule) {
|
||||
llvm::errs() << "Failed to emit LLVM IR\n";
|
||||
return -1;
|
||||
// Multiply and store into the output buffer.
|
||||
%3 = mulf %2, %2 : f64
|
||||
affine.store %3, %0[%arg0, %arg1] : memref<2x3xf64>
|
||||
}
|
||||
}
|
||||
llvm::errs() << *llvmModule << "\n";
|
||||
return 0;
|
||||
|
||||
// Print the value held by the buffer.
|
||||
"toy.print"(%0) : (memref<2x3xf64>) -> ()
|
||||
dealloc %1 : memref<2x3xf64>
|
||||
dealloc %0 : memref<2x3xf64>
|
||||
return
|
||||
}
|
||||
```
|
||||
|
||||
Adding a JIT isn't much more involved either:
|
||||
Here we can see that an allocation was removed, the two loop nests were fused,
|
||||
and we also were able to remove an unnecessary allocation! You can build
|
||||
`toyc-ch5` and try yourself: `toyc-ch5 test/lowering.toy -emit=mlir-affine`. We
|
||||
can also check our optimizations by adding `-opt`.
|
||||
|
||||
```c++
|
||||
int runJit() {
|
||||
mlir::MLIRContext context;
|
||||
auto module = loadFileAndProcessModule(context, /* EnableLowering=*/ true);
|
||||
|
||||
// Initialize LLVM targets.
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
|
||||
// Create an MLIR execution engine. Note that it takes a null pass manager
|
||||
// to make sure it won't run "default" passes on the MLIR that would trigger
|
||||
// a second conversion to LLVM IR. The execution engine eagerly JIT-compiles
|
||||
// the module.
|
||||
auto maybeEngine =
|
||||
mlir::ExecutionEngine::create(module.get(), /*pm=*/nullptr);
|
||||
assert(maybeEngine && "failed to construct an execution engine");
|
||||
auto &engine = maybeEngine.get();
|
||||
|
||||
// Invoke the JIT-compiled function with the arguments. Note that, for API
|
||||
// uniformity reasons, it takes a list of type-erased pointers to arguments.
|
||||
auto invocationResult = engine->invoke("main");
|
||||
if(invocationResult) {
|
||||
llvm::errs() << "JIT invocation failed\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
```
|
||||
|
||||
You can play with it, from the build directory:
|
||||
|
||||
```bash
|
||||
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch5 -emit=jit
|
||||
1.000000 2.000000
|
||||
3.000000 4.000000
|
||||
```
|
||||
|
||||
You can also play with `-emit=mlir`, `-emit=mlir-linalg`, `-emit=llvm-dialect`,
|
||||
and `-emit=llvm-ir` to compare the various level of IR involved. Try also
|
||||
options like `--print-ir-after-all` to track the evolution of the IR throughout
|
||||
the pipeline.
|
||||
In this chapter we explored some aspects of partial lowering, with the intent to
|
||||
optimize. In the [next chapter](Ch-6.md) we will continue the discussion about
|
||||
dialect conversion by targeting LLVM for code generation.
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# RUN: toyc-ch4 %s -emit=ast 2>&1 | FileCheck %s
|
||||
|
||||
|
||||
# User defined generic function that operates solely on
|
||||
# User defined generic function that operates on unknown shaped arguments.
|
||||
def multiply_transpose(a, b) {
|
||||
return a * transpose(b);
|
||||
}
|
||||
|
@ -11,8 +10,10 @@ def main() {
|
|||
# The shape is inferred from the supplied literal.
|
||||
var a = [[1, 2, 3], [4, 5, 6]];
|
||||
# b is identical to a, the literal array is implicitly reshaped: defining new
|
||||
# variables is the way to reshape arrays (element count must match).
|
||||
# variables is the way to reshape arrays (element count in literal must match
|
||||
# the size of specified shape).
|
||||
var b<2, 3> = [1, 2, 3, 4, 5, 6];
|
||||
|
||||
# This call will specialize `multiply_transpose` with <2, 3> for both
|
||||
# arguments and deduce a return type of <2, 2> in initialization of `c`.
|
||||
var c = multiply_transpose(a, b);
|
||||
|
@ -30,44 +31,44 @@ def main() {
|
|||
|
||||
# CHECK: Module:
|
||||
# CHECK-NEXT: Function
|
||||
# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:5:1'
|
||||
# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1'
|
||||
# CHECK-NEXT: Params: [a, b]
|
||||
# CHECK-NEXT: Block {
|
||||
# CHECK-NEXT: Retur
|
||||
# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:6:14
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:6:10
|
||||
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:6:14
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:6:24
|
||||
# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:14
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:5:10
|
||||
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:14
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:5:24
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: } // Block
|
||||
# CHECK-NEXT: Function
|
||||
# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:9:1'
|
||||
# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1'
|
||||
# CHECK-NEXT: Params: []
|
||||
# CHECK-NEXT: Block {
|
||||
# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:12:3
|
||||
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:12:11
|
||||
# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3
|
||||
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11
|
||||
# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3
|
||||
# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17
|
||||
# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:18:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:18:11
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:18:30
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:18:33
|
||||
# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:21:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:21:11
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:21:30
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:21:33
|
||||
# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:24:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:24:11
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:24:30
|
||||
# CHECK-NEXT: var: c @{{.*}}ast.toy:24:33
|
||||
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30
|
||||
# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:27:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:27:11
|
||||
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:27:30
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:27:40
|
||||
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:28:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11
|
||||
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: var: c @{{.*}}ast.toy:27:44
|
||||
# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44
|
||||
# CHECK-NEXT: ]
|
||||
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
// RUN: toyc-ch5 %s -emit=mlir-affine 2>&1 | FileCheck %s
|
||||
// RUN: toyc-ch5 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT
|
||||
|
||||
func @main() {
|
||||
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
%2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
|
||||
%3 = "toy.mul"(%2, %2) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
|
||||
"toy.print"(%3) : (tensor<2x3xf64>) -> ()
|
||||
"toy.return"() : () -> ()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @main()
|
||||
// CHECK: [[VAL_0:%.*]] = constant 1.000000e+00 : f64
|
||||
// CHECK: [[VAL_1:%.*]] = constant 2.000000e+00 : f64
|
||||
// CHECK: [[VAL_2:%.*]] = constant 3.000000e+00 : f64
|
||||
// CHECK: [[VAL_3:%.*]] = constant 4.000000e+00 : f64
|
||||
// CHECK: [[VAL_4:%.*]] = constant 5.000000e+00 : f64
|
||||
// CHECK: [[VAL_5:%.*]] = constant 6.000000e+00 : f64
|
||||
// CHECK: [[VAL_6:%.*]] = alloc() : memref<2x3xf64>
|
||||
// CHECK: [[VAL_7:%.*]] = alloc() : memref<2x3xf64>
|
||||
// CHECK: [[VAL_8:%.*]] = alloc() : memref<2x3xf64>
|
||||
// CHECK: affine.store [[VAL_0]], [[VAL_8]][0, 0] : memref<2x3xf64>
|
||||
// CHECK: affine.store [[VAL_1]], [[VAL_8]][0, 1] : memref<2x3xf64>
|
||||
// CHECK: affine.store [[VAL_2]], [[VAL_8]][0, 2] : memref<2x3xf64>
|
||||
// CHECK: affine.store [[VAL_3]], [[VAL_8]][1, 0] : memref<2x3xf64>
|
||||
// CHECK: affine.store [[VAL_4]], [[VAL_8]][1, 1] : memref<2x3xf64>
|
||||
// CHECK: affine.store [[VAL_5]], [[VAL_8]][1, 2] : memref<2x3xf64>
|
||||
// CHECK: affine.for [[VAL_9:%.*]] = 0 to 2 {
|
||||
// CHECK: affine.for [[VAL_10:%.*]] = 0 to 3 {
|
||||
// CHECK: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64>
|
||||
// CHECK: affine.store [[VAL_11]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<2x3xf64>
|
||||
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 2 {
|
||||
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 3 {
|
||||
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<2x3xf64>
|
||||
// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<2x3xf64>
|
||||
// CHECK: [[VAL_16:%.*]] = mulf [[VAL_14]], [[VAL_15]] : f64
|
||||
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<2x3xf64>
|
||||
// CHECK: "toy.print"([[VAL_6]]) : (memref<2x3xf64>) -> ()
|
||||
// CHECK: dealloc [[VAL_8]] : memref<2x3xf64>
|
||||
// CHECK: dealloc [[VAL_7]] : memref<2x3xf64>
|
||||
// CHECK: dealloc [[VAL_6]] : memref<2x3xf64>
|
||||
|
||||
// OPT-LABEL: func @main()
|
||||
// OPT: [[VAL_1:%.*]] = constant 1.000000e+00 : f64
|
||||
// OPT: [[VAL_2:%.*]] = constant 2.000000e+00 : f64
|
||||
// OPT: [[VAL_3:%.*]] = constant 3.000000e+00 : f64
|
||||
// OPT: [[VAL_4:%.*]] = constant 4.000000e+00 : f64
|
||||
// OPT: [[VAL_5:%.*]] = constant 5.000000e+00 : f64
|
||||
// OPT: [[VAL_6:%.*]] = constant 6.000000e+00 : f64
|
||||
// OPT: [[VAL_7:%.*]] = alloc() : memref<2x3xf64>
|
||||
// OPT: [[VAL_8:%.*]] = alloc() : memref<2x3xf64>
|
||||
// OPT: affine.store [[VAL_1]], [[VAL_8]]{{\[}}0, 0] : memref<2x3xf64>
|
||||
// OPT: affine.store [[VAL_2]], [[VAL_8]]{{\[}}0, 1] : memref<2x3xf64>
|
||||
// OPT: affine.store [[VAL_3]], [[VAL_8]]{{\[}}0, 2] : memref<2x3xf64>
|
||||
// OPT: affine.store [[VAL_4]], [[VAL_8]]{{\[}}1, 0] : memref<2x3xf64>
|
||||
// OPT: affine.store [[VAL_5]], [[VAL_8]]{{\[}}1, 1] : memref<2x3xf64>
|
||||
// OPT: affine.store [[VAL_6]], [[VAL_8]]{{\[}}1, 2] : memref<2x3xf64>
|
||||
// OPT: affine.for [[VAL_9:%.*]] = 0 to 2 {
|
||||
// OPT: affine.for [[VAL_10:%.*]] = 0 to 3 {
|
||||
// OPT: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64>
|
||||
// OPT: [[VAL_12:%.*]] = mulf [[VAL_11]], [[VAL_11]] : f64
|
||||
// OPT: affine.store [[VAL_12]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<2x3xf64>
|
||||
// OPT: "toy.print"([[VAL_7]]) : (memref<2x3xf64>) -> ()
|
||||
// OPT: dealloc [[VAL_8]] : memref<2x3xf64>
|
||||
// OPT: dealloc [[VAL_7]] : memref<2x3xf64>
|
|
@ -1,7 +1,6 @@
|
|||
# RUN: toyc-ch5 %s -emit=ast 2>&1 | FileCheck %s
|
||||
|
||||
|
||||
# User defined generic function that operates solely on
|
||||
# User defined generic function that operates on unknown shaped arguments.
|
||||
def multiply_transpose(a, b) {
|
||||
return a * transpose(b);
|
||||
}
|
||||
|
@ -10,9 +9,11 @@ def main() {
|
|||
# Define a variable `a` with shape <2, 3>, initialized with the literal value.
|
||||
# The shape is inferred from the supplied literal.
|
||||
var a = [[1, 2, 3], [4, 5, 6]];
|
||||
# b is identical to a, the literal array is implicitely reshaped: defining new
|
||||
# variables is the way to reshape arrays (element count must match).
|
||||
# b is identical to a, the literal array is implicitly reshaped: defining new
|
||||
# variables is the way to reshape arrays (element count in literal must match
|
||||
# the size of specified shape).
|
||||
var b<2, 3> = [1, 2, 3, 4, 5, 6];
|
||||
|
||||
# This call will specialize `multiply_transpose` with <2, 3> for both
|
||||
# arguments and deduce a return type of <2, 2> in initialization of `c`.
|
||||
var c = multiply_transpose(a, b);
|
||||
|
@ -30,44 +31,44 @@ def main() {
|
|||
|
||||
# CHECK: Module:
|
||||
# CHECK-NEXT: Function
|
||||
# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:5:1'
|
||||
# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1'
|
||||
# CHECK-NEXT: Params: [a, b]
|
||||
# CHECK-NEXT: Block {
|
||||
# CHECK-NEXT: Retur
|
||||
# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:6:14
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:6:10
|
||||
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:6:14
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:6:24
|
||||
# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:14
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:5:10
|
||||
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:14
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:5:24
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: } // Block
|
||||
# CHECK-NEXT: Function
|
||||
# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:9:1'
|
||||
# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1'
|
||||
# CHECK-NEXT: Params: []
|
||||
# CHECK-NEXT: Block {
|
||||
# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:12:3
|
||||
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:12:11
|
||||
# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3
|
||||
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11
|
||||
# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3
|
||||
# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17
|
||||
# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:18:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:18:11
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:18:30
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:18:33
|
||||
# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:21:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:21:11
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:21:30
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:21:33
|
||||
# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:24:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:24:11
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:24:30
|
||||
# CHECK-NEXT: var: c @{{.*}}ast.toy:24:33
|
||||
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11
|
||||
# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30
|
||||
# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:27:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:27:11
|
||||
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:27:30
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:27:40
|
||||
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:28:3
|
||||
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11
|
||||
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30
|
||||
# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40
|
||||
# CHECK-NEXT: ]
|
||||
# CHECK-NEXT: var: c @{{.*}}ast.toy:27:44
|
||||
# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44
|
||||
# CHECK-NEXT: ]
|
||||
|
||||
|
|
|
@ -13,20 +13,19 @@ def main() {
|
|||
print(d);
|
||||
}
|
||||
|
||||
# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
|
||||
# CHECK-NEXT: attributes {toy.generic = true} {
|
||||
# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
|
||||
# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
|
||||
# CHECK-NEXT: "toy.return"(%1) : (!toy.array) -> ()
|
||||
# CHECK-NEXT: }
|
||||
|
||||
# CHECK-LABEL: func @main() {
|
||||
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> !toy.array<2, 3>
|
||||
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3>
|
||||
# CHECK-NEXT: %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> !toy.array<6>
|
||||
# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3>
|
||||
# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
|
||||
# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
|
||||
# CHECK-NEXT: "toy.print"(%5) : (!toy.array) -> ()
|
||||
# CHECK-NEXT: "toy.return"() : () -> ()
|
||||
# CHECK-LABEL: func @multiply_transpose(
|
||||
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>)
|
||||
# CHECK-NEXT: attributes {toy.generic} {
|
||||
# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64>
|
||||
# CHECK-NEXT: [[VAL_3:%.*]] = "toy.mul"([[VAL_0]], [[VAL_2]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
||||
# CHECK-NEXT: "toy.return"([[VAL_3]]) : (tensor<*xf64>) -> ()
|
||||
|
||||
# CHECK-LABEL: func @main() {
|
||||
# CHECK-NEXT: [[VAL_4:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
# CHECK-NEXT: [[VAL_5:%.*]] = "toy.reshape"([[VAL_4]]) : (tensor<2x3xf64>) -> tensor<2x3xf64>
|
||||
# CHECK-NEXT: [[VAL_6:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
|
||||
# CHECK-NEXT: [[VAL_7:%.*]] = "toy.reshape"([[VAL_6]]) : (tensor<6xf64>) -> tensor<2x3xf64>
|
||||
# CHECK-NEXT: [[VAL_8:%.*]] = "toy.generic_call"([[VAL_5]], [[VAL_7]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||
# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_7]], [[VAL_5]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||
# CHECK-NEXT: "toy.print"([[VAL_9]]) : (tensor<*xf64>) -> ()
|
||||
# CHECK-NEXT: "toy.return"() : () -> ()
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
// RUN: not toyc-ch5 %s -emit=mlir 2>&1
|
||||
|
||||
|
||||
// This IR is not "valid":
|
||||
// The following IR is not "valid":
|
||||
// - toy.print should not return a value.
|
||||
// - toy.print should take an argument.
|
||||
// - There should be a block terminator.
|
||||
// This all round-trip since this is opaque for MLIR.
|
||||
func @main() {
|
||||
%0 = "toy.print"() : () -> !toy.array<2, 3>
|
||||
%0 = "toy.print"() : () -> tensor<2x3xf64>
|
||||
}
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
# RUN: toyc-ch5 %s -emit=llvm-ir 2>&1 | FileCheck %s
|
||||
|
||||
# User defined generic function that operates on unknown shaped arguments
|
||||
def multiply_transpose(a, b) {
|
||||
return a * transpose(b);
|
||||
}
|
||||
|
||||
# CHECK: define void @main() {
|
||||
# CHECK: %1 = call i8* @malloc(i64 mul (i64 ptrtoint (double* getelementptr (double, double* null, i64 1) to i64), i64 6))
|
||||
def main() {
|
||||
var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
|
||||
var b<2, 3> = [1, 2, 3, 4, 5, 6];
|
||||
var c = multiply_transpose(a, b);
|
||||
var d = multiply_transpose(b, a);
|
||||
print(d);
|
||||
}
|
|
@ -6,9 +6,9 @@ def main() {
|
|||
}
|
||||
|
||||
# CHECK-LABEL: func @main() {
|
||||
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<1xf64>} : () -> !toy.array<1>
|
||||
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2>
|
||||
# CHECK-NEXT: "toy.print"(%1) : (!toy.array<2, 2>) -> ()
|
||||
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<f64>} : () -> tensor<f64>
|
||||
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor<f64>) -> tensor<2x2xf64>
|
||||
# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> ()
|
||||
# CHECK-NEXT: "toy.return"() : () -> ()
|
||||
# CHECK-NEXT: }
|
||||
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
// RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s
|
||||
|
||||
// Check the result of inlining+shape inference on an input module.
|
||||
|
||||
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
|
||||
%0 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64>
|
||||
%1 = "toy.mul"(%arg0, %0) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
||||
"toy.return"(%1) : (tensor<*xf64>) -> ()
|
||||
}
|
||||
func @main() {
|
||||
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
%1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
|
||||
%2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
|
||||
%3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
|
||||
%4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||
%5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||
"toy.print"(%5) : (tensor<*xf64>) -> ()
|
||||
"toy.return"() : () -> ()
|
||||
}
|
||||
|
||||
// CHECK-NOT: func @multiply_transpose
|
||||
// CHECK-NOT: tensor<*xf64>
|
||||
|
||||
// CHECK-LABEL: func @main() {
|
||||
// CHECK: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
// CHECK: [[VAL_1:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
// CHECK: [[VAL_2:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64>
|
||||
// CHECK: [[VAL_3:%.*]] = "toy.mul"([[VAL_1]], [[VAL_2]]) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<2x2xf64>
|
||||
// CHECK: "toy.print"([[VAL_3]]) : (tensor<2x2xf64>) -> ()
|
||||
// CHECK: "toy.return"() : () -> ()
|
|
@ -1,19 +0,0 @@
|
|||
# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s
|
||||
# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
|
||||
|
||||
def transpose_transpose(x) {
|
||||
return transpose(transpose(x));
|
||||
}
|
||||
|
||||
def main() {
|
||||
print(transpose_transpose([[1, 2], [3, 4]]));
|
||||
}
|
||||
|
||||
#CHECK-LABEL: func @transpose_transpose
|
||||
#CHECK: transpose
|
||||
#CHECK-LABEL: main
|
||||
|
||||
|
||||
#OPT-LABEL: func @transpose_transpose
|
||||
#OPT-NOT: transpose
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s
|
||||
# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
|
||||
|
||||
# We expect no reshape in this function with optimizations enabled
|
||||
def foo(a) {
|
||||
var b<2,1> = a;
|
||||
var c<2,1> = b;
|
||||
print(c);
|
||||
}
|
||||
|
||||
def main() {
|
||||
var a<2, 1> = [1, 2];
|
||||
foo(a);
|
||||
}
|
||||
|
||||
# without optimizations, match the reshape
|
||||
#CHECK-LABEL: func @foo
|
||||
#CHECK: reshape
|
||||
#CHECK-LABEL: main
|
||||
|
||||
# with optimizations, ensure no reshape
|
||||
#OPT-LABEL: main
|
||||
#OPT-LABEL: func @foo_2x1
|
||||
#OPT-NOT: reshape
|
Loading…
Reference in New Issue