[mlir][toy] Define a FuncOp operation in toy and drop the dependence on FuncOp

FuncOp is being moved out of the builtin dialect, and defining a custom
toy operation showcases various aspects of defining function-like operation
(e.g. inlining, passes, etc.).

Differential Revision: https://reviews.llvm.org/D121264
This commit is contained in:
River Riddle 2022-03-08 16:21:07 -08:00
parent f96a8675cd
commit ee2c6cd906
80 changed files with 946 additions and 259 deletions

View File

@ -120,7 +120,8 @@ types, to be customized. At the same time, IR elements can always be reduced to
the above fundamental concepts. This allows MLIR to parse, represent, and
[round-trip](../../../getting_started/Glossary.md/#round-trip) IR for *any*
operation. For example, we could place our Toy operation from above into an
`.mlir` file and round-trip through *mlir-opt* without registering any dialect:
`.mlir` file and round-trip through *mlir-opt* without registering any `toy`
related dialect:
```mlir
func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
@ -558,13 +559,14 @@ Results in the following IR:
```mlir
module {
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
"toy.func"() ({
^bb0(%arg0: tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":4:1), %arg1: tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":4:1)):
%0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:10)
%1 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:25)
%2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:25)
"toy.return"(%2) : (tensor<*xf64>) -> () loc("test/Examples/Toy/Ch2/codegen.toy":5:3)
} loc("test/Examples/Toy/Ch2/codegen.toy":4:1)
func @main() {
}) {sym_name = "multiply_transpose", type = (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>} : () -> () loc("test/Examples/Toy/Ch2/codegen.toy":4:1)
"toy.func"() ({
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> loc("test/Examples/Toy/Ch2/codegen.toy":9:17)
%1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> loc("test/Examples/Toy/Ch2/codegen.toy":9:3)
%2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> loc("test/Examples/Toy/Ch2/codegen.toy":10:17)
@ -573,7 +575,7 @@ module {
%5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":12:11)
"toy.print"(%5) : (tensor<*xf64>) -> () loc("test/Examples/Toy/Ch2/codegen.toy":13:3)
"toy.return"() : () -> () loc("test/Examples/Toy/Ch2/codegen.toy":8:1)
} loc("test/Examples/Toy/Ch2/codegen.toy":8:1)
}) {sym_name = "main", type = () -> ()} : () -> () loc("test/Examples/Toy/Ch2/codegen.toy":8:1)
} loc(unknown)
```
@ -686,13 +688,13 @@ now get a much more readable:
```mlir
module {
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
toy.func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:10)
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:25)
%2 = toy.mul %0, %1 : tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:25)
toy.return %2 : tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:3)
} loc("test/Examples/Toy/Ch2/codegen.toy":4:1)
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> loc("test/Examples/Toy/Ch2/codegen.toy":9:17)
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64> loc("test/Examples/Toy/Ch2/codegen.toy":9:3)
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> loc("test/Examples/Toy/Ch2/codegen.toy":10:17)

View File

@ -37,7 +37,7 @@ def transpose_transpose(x) {
Which corresponds to the following IR:
```mlir
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
toy.return %1 : tensor<*xf64>
@ -125,14 +125,14 @@ similar way to LLVM:
```c++
mlir::PassManager pm(module.getContext());
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
```
Finally, we can run `toyc-ch3 test/Examples/Toy/Ch3/transpose_transpose.toy
-emit=mlir -opt` and observe our pattern in action:
```mlir
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
toy.return %arg0 : tensor<*xf64>
}
@ -153,7 +153,7 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {...}
Let's retry now `toyc-ch3 test/transpose_transpose.toy -emit=mlir -opt`:
```mlir
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
toy.return %arg0 : tensor<*xf64>
}
```
@ -228,7 +228,7 @@ def main() {
```mlir
module {
func @main() {
toy.func @main() {
%0 = toy.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
%1 = toy.reshape(%0 : tensor<2xf64>) to tensor<2x1xf64>
%2 = toy.reshape(%1 : tensor<2x1xf64>) to tensor<2x1xf64>
@ -244,7 +244,7 @@ We can try to run `toyc-ch3 test/Examples/Toy/Ch3/trivial_reshape.toy -emit=mlir
```mlir
module {
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf64>
toy.print %0 : tensor<2x1xf64>
toy.return

View File

@ -77,6 +77,14 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
return true;
}
/// This hook cheks if the given 'src' region can be inlined into the 'dest'
/// region. The regions here are the bodies of the callable functions. For
/// Toy, any function can be inlined, so we simply return true.
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
return true;
}
/// This hook is called when a terminator operation has been inlined. The only
/// terminator that we have in the Toy dialect is the return
/// operation(toy.return). We handle the return by replacing the values
@ -101,7 +109,7 @@ main function) in the MLIR generator.
```c++
/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
...
// If this function isn't main, then set the visibility to private.
if (funcAST.getProto()->getName() != "main")
@ -121,12 +129,12 @@ void ToyDialect::initialize() {
```
Next, we need to provide a way for the inliner to know that `toy.generic_call`
represents a call to a function. MLIR provides an
[operation interface](../../Interfaces.md/#attributeoperationtype-interfaces) that can be used
to mark an operation as being "call-like". Unlike dialect interfaces, operation
interfaces provide a more refined granularity of information that is specific
and core to a single operation. The interface that we will be adding here is the
`CallOpInterface`.
represents a call, and `toy.func` represents a function. MLIR provides
[operation interfaces](../../Interfaces.md/#attributeoperationtype-interfaces) that can be used
to mark an operation as being "call-like" or "callable-like". Unlike dialect interfaces,
operation interfaces provide a more refined granularity of information that is specific
and core to a single operation. The interfaces that we will be adding here is the
`CallOpInterface` and `CallableOpInterface`.
To add this interface we just need to include the definition into our operation
specification file (`Ops.td`):
@ -138,6 +146,11 @@ include "mlir/Interfaces/CallInterfaces.td"
and add it to the traits list of `GenericCallOp`:
```tablegen
def FuncOp : Toy_Op<"func",
[DeclareOpInterfaceMethods<CallableOpInterface>]> {
...
}
def GenericCallOp : Toy_Op<"generic_call",
[DeclareOpInterfaceMethods<CallOpInterface>]> {
...
@ -149,6 +162,15 @@ auto-declare all of the interface methods in the class declaration of
GenericCallOp. This means that we just need to provide a definition:
```c++
/// Returns the region on the function operation that is callable.
Region *FuncOp::getCallableRegion() { return &getBody(); }
/// Returns the results types that the callable region produces when
/// executed.
ArrayRef<Type> FuncOp::getCallableResults() { return getType().getResults(); }
// ....
/// Return the callee of the generic call operation, this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
@ -170,13 +192,13 @@ inliner pass to the pass manager for Toy:
Now let's look at a working example:
```mlir
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
toy.func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
%2 = toy.mul %0, %1 : tensor<*xf64>
toy.return %2 : tensor<*xf64>
}
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@ -214,6 +236,7 @@ def CastOp : Toy_Op<"cast", [
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
```
@ -263,14 +286,14 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
If we run the working example through the pipeline again, we get the expected:
```mlir
func @main() {
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
%1 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
%2 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<*xf64>
%3 = "toy.cast"(%0) : (tensor<2x3xf64>) -> tensor<*xf64>
%4 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64>
%5 = "toy.transpose"(%3) : (tensor<*xf64>) -> tensor<*xf64>
%6 = "toy.mul"(%4, %5) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.cast %1 : tensor<2x3xf64> to tensor<*xf64>
%3 = toy.cast %0 : tensor<2x3xf64> to tensor<*xf64>
%4 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64>
%5 = toy.transpose(%3 : tensor<*xf64>) to tensor<*xf64>
%6 = toy.mul %4, %5 : tensor<*xf64>
toy.print %6 : tensor<*xf64>
toy.return
}
@ -357,10 +380,10 @@ void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
At this point, each of the necessary Toy operations provide a mechanism by which
to infer their output shapes. The ShapeInferencePass will operate on functions:
it will run on each Function in isolation. MLIR also supports general
it will run on each function in isolation. MLIR also supports general
[OperationPasses](../../PassManagement.md#operation-pass) that run on any
isolated operation (i.e. other function-like operations), but here our module
only contains functions, so there is no need to generalize to all operations.
isolated operation, but here our module only contains functions, so there is no
need to generalize to all operations.
Implementing such a pass is done by creating a class inheriting from
`mlir::OperationPass<FuncOp>` and overriding the `runOnOperation()` method.
@ -421,10 +444,10 @@ We can then add our pass to the pass manager:
If we rerun our original example, we now get the following:
```mlir
func @main() {
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
%1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
%2 = "toy.mul"(%1, %1) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%2 = toy.mul %1, %1 : tensor<3x2xf64>
toy.print %2 : tensor<3x2xf64>
toy.return
}

View File

@ -172,8 +172,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our *illegal*
// operations were not converted successfully.
mlir::FuncOp function = getOperation();
if (mlir::failed(mlir::applyPartialConversion(function, target, patterns)))
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns)))
signalPassFailure();
}
```
@ -232,7 +231,7 @@ def PrintOp : Toy_Op<"print"> {
Let's take a concrete example:
```mlir
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>

View File

@ -119,7 +119,7 @@ that only legal operations will remain after the conversion.
Looking back at our current working example:
```mlir
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>

View File

@ -327,7 +327,7 @@ Which generates the following:
```mlir
module {
func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
toy.func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
toy.return
}
}
@ -405,7 +405,7 @@ and finally get a full MLIR module:
```mlir
module {
func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> {
toy.func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> {
%0 = toy.struct_access %arg0[0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
%1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
%2 = toy.struct_access %arg0[1] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
@ -413,7 +413,7 @@ module {
%4 = toy.mul %1, %3 : tensor<*xf64>
toy.return %4 : tensor<*xf64>
}
func @main() {
toy.func @main() {
%0 = toy.struct_constant [
dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>,
dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
@ -434,7 +434,7 @@ After inlining, the MLIR module in the previous section looks something like:
```mlir
module {
func @main() {
toy.func @main() {
%0 = toy.struct_constant [
dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>,
dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
@ -500,7 +500,7 @@ changes to our pipeline.
```mlir
module {
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%2 = toy.mul %1, %1 : tensor<3x2xf64>

View File

@ -14,8 +14,10 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
/// Include the auto-generated header file containing the declaration of the toy

View File

@ -14,6 +14,8 @@
#define TOY_OPS
include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Provide a definition of the 'toy' dialect in the ODS framework so that we
@ -106,6 +108,61 @@ def AddOp : Toy_Op<"add"> {
];
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
def FuncOp : Toy_Op<"func", [
FunctionOpInterface, IsolatedFromAbove, Symbol
]> {
let summary = "user defined function operation";
let description = [{
The "toy.func" operation represents a user defined function. These are
callable SSA-region operations that contain toy computations.
Example:
```mlir
toy.func @main() {
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
toy.print %1 : tensor<2x2xf64>
toy.return
}
```
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -15,6 +15,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
using namespace mlir;
@ -187,6 +188,39 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
mlir::SymbolRefAttr::get(builder.getContext(), callee));
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
auto buildFuncType =
[](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
llvm::ArrayRef<mlir::Type> results,
mlir::function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

View File

@ -58,12 +58,8 @@ public:
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);
// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
@ -108,7 +104,7 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());
// This is a generic function, the return type will be inferred later.
@ -116,23 +112,23 @@ private:
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}
/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.

View File

@ -14,8 +14,10 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
/// Include the auto-generated header file containing the declaration of the toy

View File

@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Provide a definition of the 'toy' dialect in the ODS framework so that we
@ -105,6 +107,61 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> {
];
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
def FuncOp : Toy_Op<"func", [
FunctionOpInterface, IsolatedFromAbove, Symbol
]> {
let summary = "user defined function operation";
let description = [{
The "toy.func" operation represents a user defined function. These are
callable SSA-region operations that contain toy computations.
Example:
```mlir
toy.func @main() {
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
toy.print %1 : tensor<2x2xf64>
toy.return
}
```
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -15,6 +15,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
using namespace mlir;
@ -174,6 +175,39 @@ mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,
void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
auto buildFuncType =
[](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
llvm::ArrayRef<mlir::Type> results,
mlir::function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -58,12 +58,8 @@ public:
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);
// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
@ -108,7 +104,7 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());
// This is a generic function, the return type will be inferred later.
@ -116,23 +112,23 @@ private:
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}
/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.

View File

@ -118,7 +118,7 @@ int dumpMLIR() {
applyPassManagerCLOptions(pm);
// Add a run of the canonicalizer to optimize the mlir module.
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
if (mlir::failed(pm.run(*module)))
return 4;
}

View File

@ -14,9 +14,11 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"

View File

@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -134,6 +136,62 @@ def CastOp : Toy_Op<"cast", [
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
def FuncOp : Toy_Op<"func", [
DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
IsolatedFromAbove, Symbol
]> {
let summary = "user defined function operation";
let description = [{
The "toy.func" operation represents a user defined function. These are
callable SSA-region operations that contain toy computations.
Example:
```mlir
toy.func @main() {
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
toy.print %1 : tensor<2x2xf64>
toy.return
}
```
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -15,6 +15,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
@ -48,6 +49,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
return true;
}
// All functions within toy can be inlined.
bool isLegalToInline(Region *, Region *, bool,
BlockAndValueMapping &) const final {
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
@ -257,6 +264,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return !input.hasRank() || !output.hasRank() || input == output;
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
auto buildFuncType =
[](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
llvm::ArrayRef<mlir::Type> results,
mlir::function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
/// Returns the region on the function operation that is callable.
mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }
/// Returns the results types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getType().getResults();
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -58,12 +58,8 @@ public:
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);
// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
@ -108,7 +104,7 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());
// This is a generic function, the return type will be inferred later.
@ -116,23 +112,23 @@ private:
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}
/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.

View File

@ -45,7 +45,7 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
public:
void runOnOperation() override {
auto f = getOperation();

View File

@ -123,7 +123,7 @@ int dumpMLIR() {
// Now that there is only one function, we can infer the shapes of each of
// the operations.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());

View File

@ -14,9 +14,11 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"

View File

@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -134,6 +136,62 @@ def CastOp : Toy_Op<"cast", [
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
def FuncOp : Toy_Op<"func", [
DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
IsolatedFromAbove, Symbol
]> {
let summary = "user defined function operation";
let description = [{
The "toy.func" operation represents a user defined function. These are
callable SSA-region operations that contain toy computations.
Example:
```mlir
toy.func @main() {
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
toy.print %1 : tensor<2x2xf64>
toy.return
}
```
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -15,6 +15,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
@ -48,6 +49,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
return true;
}
// All functions within toy can be inlined.
bool isLegalToInline(Region *, Region *, bool,
BlockAndValueMapping &) const final {
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
@ -257,6 +264,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return !input.hasRank() || !output.hasRank() || input == output;
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
auto buildFuncType =
[](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
llvm::ArrayRef<mlir::Type> results,
mlir::function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
/// Returns the region on the function operation that is callable.
mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }
/// Returns the results types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getType().getResults();
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinDialect.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
@ -197,6 +198,37 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
using OpConversionPattern<toy::FuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// We only lower the main function as we expect that all other functions
// have been inlined.
if (op.getName() != "main")
return failure();
// Verify that the given main has no inputs and results.
if (op.getNumArguments() || op.getType().getNumResults()) {
return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
diag << "expected 'main' to have 0 inputs and 0 results";
});
}
// Create a new non-toy function, with the same region.
auto func =
rewriter.create<mlir::FuncOp>(op.getLoc(), op.getName(), op.getType());
rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
rewriter.eraseOp(op);
return success();
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Print operations
//===----------------------------------------------------------------------===//
@ -277,7 +309,7 @@ struct TransposeOpLowering : public ConversionPattern {
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<FuncOp>> {
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, func::FuncDialect, memref::MemRefDialect>();
}
@ -286,19 +318,6 @@ struct ToyToAffineLoweringPass
} // namespace
void ToyToAffineLoweringPass::runOnOperation() {
FuncOp function = getOperation();
// We only lower the main function as we expect that all other functions have
// been inlined.
if (function.getName() != "main")
return;
// Verify that the given main has no inputs and results.
if (function.getNumArguments() || function.getType().getNumResults()) {
function.emitError("expected 'main' to have 0 inputs and 0 results");
return signalPassFailure();
}
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
@ -306,8 +325,9 @@ void ToyToAffineLoweringPass::runOnOperation() {
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine`, `Arithmetic`, `Func`, and `MemRef` dialects.
target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();
target
.addLegalDialect<AffineDialect, BuiltinDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
@ -324,7 +344,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
&getContext());

View File

@ -58,12 +58,8 @@ public:
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);
// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
@ -108,7 +104,7 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());
// This is a generic function, the return type will be inferred later.
@ -116,23 +112,23 @@ private:
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}
/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.

View File

@ -45,7 +45,7 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
public:
void runOnOperation() override {
auto f = getOperation();

View File

@ -130,17 +130,18 @@ int dumpMLIR() {
// Now that there is only one function, we can infer the shapes of each of
// the operations.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
}
if (isLoweringToAffine) {
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
// Partially lower the toy dialect.
pm.addPass(mlir::toy::createLowerToAffinePass());
// Partially lower the toy dialect with a few cleanups afterwards.
optPM.addPass(mlir::toy::createLowerToAffinePass());
// Add a few cleanups post lowering.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());

View File

@ -14,9 +14,11 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"

View File

@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -134,6 +136,62 @@ def CastOp : Toy_Op<"cast", [
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
def FuncOp : Toy_Op<"func", [
DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
IsolatedFromAbove, Symbol
]> {
let summary = "user defined function operation";
let description = [{
The "toy.func" operation represents a user defined function. These are
callable SSA-region operations that contain toy computations.
Example:
```mlir
toy.func @main() {
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
toy.print %1 : tensor<2x2xf64>
toy.return
}
```
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -15,6 +15,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
@ -48,6 +49,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
return true;
}
// All functions within toy can be inlined.
bool isLegalToInline(Region *, Region *, bool,
BlockAndValueMapping &) const final {
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
@ -257,6 +264,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return !input.hasRank() || !output.hasRank() || input == output;
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
auto buildFuncType =
[](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
llvm::ArrayRef<mlir::Type> results,
mlir::function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
/// Returns the region on the function operation that is callable.
mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }
/// Returns the results types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getType().getResults();
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinDialect.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
@ -197,6 +198,37 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
using OpConversionPattern<toy::FuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// We only lower the main function as we expect that all other functions
// have been inlined.
if (op.getName() != "main")
return failure();
// Verify that the given main has no inputs and results.
if (op.getNumArguments() || op.getType().getNumResults()) {
return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
diag << "expected 'main' to have 0 inputs and 0 results";
});
}
// Create a new non-toy function, with the same region.
auto func =
rewriter.create<mlir::FuncOp>(op.getLoc(), op.getName(), op.getType());
rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
rewriter.eraseOp(op);
return success();
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Print operations
//===----------------------------------------------------------------------===//
@ -277,7 +309,7 @@ struct TransposeOpLowering : public ConversionPattern {
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<FuncOp>> {
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, func::FuncDialect, memref::MemRefDialect>();
}
@ -286,19 +318,6 @@ struct ToyToAffineLoweringPass
} // namespace
void ToyToAffineLoweringPass::runOnOperation() {
auto function = getOperation();
// We only lower the main function as we expect that all other functions have
// been inlined.
if (function.getName() != "main")
return;
// Verify that the given main has no inputs and results.
if (function.getNumArguments() || function.getType().getNumResults()) {
function.emitError("expected 'main' to have 0 inputs and 0 results");
return signalPassFailure();
}
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
@ -306,8 +325,9 @@ void ToyToAffineLoweringPass::runOnOperation() {
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine`, `Arithmetic`, `Func`, and `MemRef` dialects.
target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();
target
.addLegalDialect<AffineDialect, BuiltinDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
@ -324,7 +344,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
&getContext());

View File

@ -58,12 +58,8 @@ public:
// add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &f : moduleAST) {
auto func = mlirGen(f);
if (!func)
return nullptr;
theModule.push_back(func);
}
for (FunctionAST &f : moduleAST)
mlirGen(f);
// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
@ -108,7 +104,7 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());
// This is a generic function, the return type will be inferred later.
@ -116,23 +112,23 @@ private:
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{}));
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}
/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.

View File

@ -45,7 +45,7 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
public:
void runOnOperation() override {
auto f = getOperation();

View File

@ -146,17 +146,18 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
// Now that there is only one function, we can infer the shapes of each of
// the operations.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
}
if (isLoweringToAffine) {
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
// Partially lower the toy dialect.
pm.addPass(mlir::toy::createLowerToAffinePass());
// Partially lower the toy dialect with a few cleanups afterwards.
optPM.addPass(mlir::toy::createLowerToAffinePass());
// Add a few cleanups post lowering.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());

View File

@ -14,9 +14,11 @@
#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
#define MLIR_TUTORIAL_TOY_DIALECT_H_
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "toy/ShapeInferenceInterface.h"

View File

@ -13,6 +13,8 @@
#ifndef TOY_OPS
#define TOY_OPS
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -153,6 +155,62 @@ def CastOp : Toy_Op<"cast", [
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
def FuncOp : Toy_Op<"func", [
DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
IsolatedFromAbove, Symbol
]> {
let summary = "user defined function operation";
let description = [{
The "toy.func" operation represents a user defined function. These are
callable SSA-region operations that contain toy computations.
Example:
```mlir
toy.func @main() {
%0 = toy.constant dense<5.500000e+00> : tensor<f64>
%1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
toy.print %1 : tensor<2x2xf64>
toy.return
}
```
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$type
);
let regions = (region AnyRegion:$body);
let builders = [OpBuilder<(ins
"StringRef":$name, "FunctionType":$type,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -16,6 +16,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
@ -49,6 +50,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
return true;
}
// All functions within toy can be inlined.
bool isLegalToInline(Region *, Region *, bool,
BlockAndValueMapping &) const final {
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
@ -284,6 +291,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return !input.hasRank() || !output.hasRank() || input == output;
}
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef name, mlir::FunctionType type,
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// FunctionOpInterface provides a convenient `build` method that will populate
// the state of our FuncOp, and create an entry block.
buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
}
mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
// Dispatch to the FunctionOpInterface provided utility method that parses the
// function operation.
auto buildFuncType =
[](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
llvm::ArrayRef<mlir::Type> results,
mlir::function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return mlir::function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(mlir::OpAsmPrinter &p) {
// Dispatch to the FunctionOpInterface provided utility method that prints the
// function operation.
mlir::function_interface_impl::printFunctionOp(p, *this,
/*isVariadic=*/false);
}
/// Returns the region on the function operation that is callable.
mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }
/// Returns the results types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
return getType().getResults();
}
//===----------------------------------------------------------------------===//
// GenericCallOp
//===----------------------------------------------------------------------===//

View File

@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinDialect.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
@ -197,6 +198,37 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
using OpConversionPattern<toy::FuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// We only lower the main function as we expect that all other functions
// have been inlined.
if (op.getName() != "main")
return failure();
// Verify that the given main has no inputs and results.
if (op.getNumArguments() || op.getType().getNumResults()) {
return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
diag << "expected 'main' to have 0 inputs and 0 results";
});
}
// Create a new non-toy function, with the same region.
auto func =
rewriter.create<mlir::FuncOp>(op.getLoc(), op.getName(), op.getType());
rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
rewriter.eraseOp(op);
return success();
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Print operations
//===----------------------------------------------------------------------===//
@ -277,7 +309,7 @@ struct TransposeOpLowering : public ConversionPattern {
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<FuncOp>> {
: public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, func::FuncDialect, memref::MemRefDialect>();
}
@ -286,19 +318,6 @@ struct ToyToAffineLoweringPass
} // namespace
void ToyToAffineLoweringPass::runOnOperation() {
auto function = getOperation();
// We only lower the main function as we expect that all other functions have
// been inlined.
if (function.getName() != "main")
return;
// Verify that the given main has no inputs and results.
if (function.getNumArguments() || function.getType().getNumResults()) {
function.emitError("expected 'main' to have 0 inputs and 0 results");
return signalPassFailure();
}
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
@ -306,8 +325,9 @@ void ToyToAffineLoweringPass::runOnOperation() {
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine`, `Arithmetic`, `Func`, and `MemRef` dialects.
target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();
target
.addLegalDialect<AffineDialect, BuiltinDialect, arith::ArithmeticDialect,
func::FuncDialect, memref::MemRefDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
@ -324,7 +344,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
&getContext());

View File

@ -60,11 +60,9 @@ public:
for (auto &record : moduleAST) {
if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) {
auto func = mlirGen(*funcAST);
mlir::toy::FuncOp func = mlirGen(*funcAST);
if (!func)
return nullptr;
theModule.push_back(func);
functionMap.insert({func.getName(), func});
} else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) {
if (failed(mlirGen(*str)))
@ -105,7 +103,7 @@ private:
std::pair<mlir::Value, VarDeclExprAST *>>;
/// A mapping for the functions that have been code generated to MLIR.
llvm::StringMap<mlir::FuncOp> functionMap;
llvm::StringMap<mlir::toy::FuncOp> functionMap;
/// A mapping for named struct types to the underlying MLIR type and the
/// original AST node.
@ -157,7 +155,7 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());
// This is a generic function, the return type will be inferred later.
@ -170,23 +168,23 @@ private:
argTypes.push_back(type);
}
auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), funcType);
return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
funcType);
}
/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
SymbolTableScopeT varScope(symbolTable);
// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
builder.setInsertionPointToEnd(theModule.getBody());
mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
mlir::Block &entryBlock = function.front();
auto protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
@ -519,7 +517,7 @@ private:
emitError(location) << "no defined function found for '" << callee << "'";
return nullptr;
}
mlir::FuncOp calledFunc = calledFuncIt->second;
mlir::toy::FuncOp calledFunc = calledFuncIt->second;
return builder.create<GenericCallOp>(
location, calledFunc.getType().getResult(0),
mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);

View File

@ -45,7 +45,7 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
public:
void runOnOperation() override {
auto f = getOperation();

View File

@ -146,7 +146,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
// Now that there is only one function, we can infer the shapes of each of
// the operations.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
@ -154,10 +154,11 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
}
if (isLoweringToAffine) {
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
// Partially lower the toy dialect.
pm.addPass(mlir::toy::createLowerToAffinePass());
// Partially lower the toy dialect with a few cleanups afterwards.
optPM.addPass(mlir::toy::createLowerToAffinePass());
// Add a few cleanups post lowering.
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());

View File

@ -13,14 +13,14 @@ def main() {
print(d);
}
# CHECK-LABEL: func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK-LABEL: toy.func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

View File

@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}

View File

@ -5,7 +5,7 @@ def main() {
print(a);
}
# CHECK-LABEL: func @main() {
# CHECK-LABEL: toy.func @main() {
# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
# CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

View File

@ -13,14 +13,14 @@ def main() {
print(d);
}
# CHECK-LABEL: func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK-LABEL: toy.func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

View File

@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}

View File

@ -5,7 +5,7 @@ def main() {
print(a);
}
# CHECK-LABEL: func @main() {
# CHECK-LABEL: toy.func @main() {
# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
# CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

View File

@ -11,11 +11,11 @@ def main() {
print(b);
}
# CHECK-LABEL: func @transpose_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK-LABEL: toy.func @transpose_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_0]] : tensor<*xf64>
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_2:%.*]] = toy.generic_call @transpose_transpose([[VAL_1]]) : (tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: toy.print [[VAL_2]] : tensor<*xf64>

View File

@ -7,7 +7,7 @@ def main() {
print(c);
}
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
# CHECK-SAME: dense<[
# CHECK-SAME: [1.000000e+00], [2.000000e+00]

View File

@ -13,14 +13,14 @@ def main() {
print(d);
}
# CHECK-LABEL: func private @multiply_transpose(
# CHECK-LABEL: toy.func private @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

View File

@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}

View File

@ -5,7 +5,7 @@ def main() {
print(a);
}
# CHECK-LABEL: func @main() {
# CHECK-LABEL: toy.func @main() {
# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
# CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

View File

@ -2,13 +2,13 @@
// Check the result of inlining+shape inference on an input module.
func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
%2 = toy.mul %0, %1 : tensor<*xf64>
toy.return %2 : tensor<*xf64>
}
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@ -19,10 +19,10 @@ func @main() {
toy.return
}
// CHECK-NOT: func private @multiply_transpose
// CHECK-NOT: toy.func private @multiply_transpose
// CHECK-NOT: tensor<*xf64>
// CHECK-LABEL: func @main()
// CHECK-LABEL: toy.func @main()
// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>

View File

@ -11,7 +11,7 @@ def main() {
print(b);
}
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
# CHECK-NEXT: toy.return

View File

@ -7,7 +7,7 @@ def main() {
print(c);
}
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
# CHECK-SAME: dense<[
# CHECK-SAME: [1.000000e+00], [2.000000e+00]

View File

@ -1,7 +1,7 @@
// RUN: toyc-ch5 %s -emit=mlir-affine 2>&1 | FileCheck %s
// RUN: toyc-ch5 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>

View File

@ -13,14 +13,14 @@ def main() {
print(d);
}
# CHECK-LABEL: func private @multiply_transpose(
# CHECK-LABEL: toy.func private @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

View File

@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}

View File

@ -2,13 +2,13 @@
// Check the result of inlining+shape inference on an input module.
func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
%2 = toy.mul %0, %1 : tensor<*xf64>
toy.return %2 : tensor<*xf64>
}
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@ -19,10 +19,10 @@ func @main() {
toy.return
}
// CHECK-NOT: func @multiply_transpose
// CHECK-NOT: toy.func @multiply_transpose
// CHECK-NOT: tensor<*xf64>
// CHECK-LABEL: func @main()
// CHECK-LABEL: toy.func @main()
// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>

View File

@ -11,7 +11,7 @@ def main() {
print(b);
}
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
# CHECK-NEXT: toy.return

View File

@ -7,7 +7,7 @@ def main() {
print(c);
}
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
# CHECK-SAME: dense<[
# CHECK-SAME: [1.000000e+00], [2.000000e+00]

View File

@ -1,7 +1,7 @@
// RUN: toyc-ch6 %s -emit=mlir-affine 2>&1 | FileCheck %s
// RUN: toyc-ch6 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>

View File

@ -13,14 +13,14 @@ def main() {
print(d);
}
# CHECK-LABEL: func private @multiply_transpose(
# CHECK-LABEL: toy.func private @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

View File

@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}

View File

@ -1,6 +1,6 @@
// RUN: toyc-ch6 %s -emit=llvm -opt
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>

View File

@ -5,7 +5,7 @@ def main() {
print(a);
}
# CHECK-LABEL: func @main() {
# CHECK-LABEL: toy.func @main() {
# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
# CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

View File

@ -2,13 +2,13 @@
// Check the result of inlining+shape inference on an input module.
func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
%2 = toy.mul %0, %1 : tensor<*xf64>
toy.return %2 : tensor<*xf64>
}
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
@ -19,10 +19,10 @@ func @main() {
toy.return
}
// CHECK-NOT: func @multiply_transpose
// CHECK-NOT: toy.func @multiply_transpose
// CHECK-NOT: tensor<*xf64>
// CHECK-LABEL: func @main()
// CHECK-LABEL: toy.func @main()
// CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
// CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
// CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>

View File

@ -11,7 +11,7 @@ def main() {
print(b);
}
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
# CHECK-NEXT: toy.return

View File

@ -7,7 +7,7 @@ def main() {
print(c);
}
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
# CHECK-SAME: dense<[
# CHECK-SAME: [1.000000e+00], [2.000000e+00]

View File

@ -1,7 +1,7 @@
// RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
// RUN: toyc-ch7 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>

View File

@ -13,14 +13,14 @@ def main() {
print(d);
}
# CHECK-LABEL: func private @multiply_transpose(
# CHECK-LABEL: toy.func private @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
# CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
# CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

View File

@ -4,6 +4,6 @@
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
func @main() {
toy.func @main() {
%0 = "toy.print"() : () -> tensor<2x3xf64>
}

View File

@ -1,6 +1,6 @@
// RUN: toyc-ch7 %s -emit=llvm -opt
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%3 = toy.mul %2, %2 : tensor<3x2xf64>

View File

@ -5,7 +5,7 @@ def main() {
print(a);
}
# CHECK-LABEL: func @main() {
# CHECK-LABEL: toy.func @main() {
# CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
# CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
# CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

View File

@ -2,13 +2,13 @@
// Check the result of inlining+shape inference on an input module.
func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
%2 = toy.mul %0, %1 : tensor<*xf64>
toy.return %2 : tensor<*xf64>
}
func @main() {
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
%2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

View File

@ -21,7 +21,7 @@ def main() {
print(c);
}
# CHECK-LABEL: func private @multiply_transpose(
# CHECK-LABEL: toy.func private @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
# CHECK-NEXT: [[VAL_1:%.*]] = toy.struct_access [[VAL_0]][0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
# CHECK-NEXT: [[VAL_2:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
@ -30,13 +30,13 @@ def main() {
# CHECK-NEXT: [[VAL_5:%.*]] = toy.mul [[VAL_2]], [[VAL_4]] : tensor<*xf64>
# CHECK-NEXT: toy.return [[VAL_5]] : tensor<*xf64>
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_6:%.*]] = toy.struct_constant [dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>] : !toy.struct<tensor<*xf64>, tensor<*xf64>>
# CHECK-NEXT: [[VAL_7:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]]) : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
# CHECK-NEXT: toy.print [[VAL_7]] : tensor<*xf64>
# CHECK-NEXT: toy.return
# OPT-LABEL: func @main()
# OPT-LABEL: toy.func @main()
# OPT-NEXT: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# OPT-NEXT: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
# OPT-NEXT: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>

View File

@ -1,6 +1,6 @@
// RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s
func @main() {
toy.func @main() {
%0 = toy.struct_constant [
[dense<4.000000e+00> : tensor<2x2xf64>], dense<4.000000e+00> : tensor<2x2xf64>
] : !toy.struct<!toy.struct<tensor<*xf64>>, tensor<*xf64>>
@ -10,6 +10,6 @@ func @main() {
toy.return
}
// CHECK-LABEL: func @main
// CHECK-LABEL: toy.func @main
// CHECK-NEXT: %[[CST:.*]] = toy.constant dense<4.0
// CHECK-NEXT: toy.print %[[CST]]

View File

@ -11,7 +11,7 @@ def main() {
print(b);
}
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
# CHECK-NEXT: toy.return

View File

@ -7,7 +7,7 @@ def main() {
print(c);
}
# CHECK-LABEL: func @main()
# CHECK-LABEL: toy.func @main()
# CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
# CHECK-SAME: dense<[
# CHECK-SAME: [1.000000e+00], [2.000000e+00]