forked from OSchip/llvm-project

[mlir][toy] Define a FuncOp operation in toy and drop the dependence on FuncOp

FuncOp is being moved out of the builtin dialect, and defining a custom toy
operation showcases various aspects of defining function-like operations
(e.g. inlining, passes, etc.).

Differential Revision: https://reviews.llvm.org/D121264

parent f96a8675cd
commit ee2c6cd906

@@ -120,7 +120,8 @@ types, to be customized. At the same time, IR elements can always be reduced to
 the above fundamental concepts. This allows MLIR to parse, represent, and
 [round-trip](../../../getting_started/Glossary.md/#round-trip) IR for *any*
 operation. For example, we could place our Toy operation from above into an
-`.mlir` file and round-trip through *mlir-opt* without registering any dialect:
+`.mlir` file and round-trip through *mlir-opt* without registering any `toy`
+related dialect:
 
 ```mlir
 func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {

@@ -558,13 +559,14 @@ Results in the following IR:
 
 ```mlir
 module {
-  func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+  "toy.func"() ({
+  ^bb0(%arg0: tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":4:1), %arg1: tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":4:1)):
     %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:10)
     %1 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:25)
     %2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:25)
     "toy.return"(%2) : (tensor<*xf64>) -> () loc("test/Examples/Toy/Ch2/codegen.toy":5:3)
-  } loc("test/Examples/Toy/Ch2/codegen.toy":4:1)
-  func @main() {
+  }) {sym_name = "multiply_transpose", type = (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>} : () -> () loc("test/Examples/Toy/Ch2/codegen.toy":4:1)
+  "toy.func"() ({
     %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> loc("test/Examples/Toy/Ch2/codegen.toy":9:17)
     %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> loc("test/Examples/Toy/Ch2/codegen.toy":9:3)
     %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> loc("test/Examples/Toy/Ch2/codegen.toy":10:17)

@@ -573,7 +575,7 @@ module {
     %5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":12:11)
     "toy.print"(%5) : (tensor<*xf64>) -> () loc("test/Examples/Toy/Ch2/codegen.toy":13:3)
     "toy.return"() : () -> () loc("test/Examples/Toy/Ch2/codegen.toy":8:1)
-  } loc("test/Examples/Toy/Ch2/codegen.toy":8:1)
+  }) {sym_name = "main", type = () -> ()} : () -> () loc("test/Examples/Toy/Ch2/codegen.toy":8:1)
 } loc(unknown)
 ```
 
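Editorial note, not part of the commit: the listings above are MLIR's *generic* operation form, which any registered or unregistered op can round-trip through. A minimal C++ sketch of asking the printer for that form explicitly, assuming only an already-constructed `mlir::Operation *`:

```c++
#include "mlir/IR/Operation.h"
#include "llvm/Support/raw_ostream.h"

// Print `op` in the fully generic form shown above, even when its dialect
// defines a custom assembly format.
void dumpGenericForm(mlir::Operation *op) {
  op->print(llvm::outs(), mlir::OpPrintingFlags().printGenericOpForm());
}
```
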
@@ -686,13 +688,13 @@ now get a much more readable:
 
 ```mlir
 module {
-  func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+  toy.func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
     %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:10)
     %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:25)
     %2 = toy.mul %0, %1 : tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:25)
     toy.return %2 : tensor<*xf64> loc("test/Examples/Toy/Ch2/codegen.toy":5:3)
   } loc("test/Examples/Toy/Ch2/codegen.toy":4:1)
-  func @main() {
+  toy.func @main() {
     %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> loc("test/Examples/Toy/Ch2/codegen.toy":9:17)
     %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64> loc("test/Examples/Toy/Ch2/codegen.toy":9:3)
     %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> loc("test/Examples/Toy/Ch2/codegen.toy":10:17)

@@ -37,7 +37,7 @@ def transpose_transpose(x) {
 Which corresponds to the following IR:
 
 ```mlir
-func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
+toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
   %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
   %1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
   toy.return %1 : tensor<*xf64>

@@ -125,14 +125,14 @@ similar way to LLVM:
 
 ```c++
 mlir::PassManager pm(module.getContext());
-pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
 ```
 
 Finally, we can run `toyc-ch3 test/Examples/Toy/Ch3/transpose_transpose.toy
 -emit=mlir -opt` and observe our pattern in action:
 
 ```mlir
-func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
+toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
   %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
   toy.return %arg0 : tensor<*xf64>
 }
 
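Editorial note, not from the commit: for context, here is roughly how the snippet above slots into a full pipeline. This is a sketch that assumes the toy dialect header is available for `mlir::toy::FuncOp`:

```c++
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "toy/Dialect.h" // assumed: provides mlir::toy::FuncOp

mlir::LogicalResult runCleanups(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Nest the canonicalizer so it runs once on every toy.func in the module.
  pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
  return pm.run(module);
}
```
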
@@ -153,7 +153,7 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {...}
 Let's retry now `toyc-ch3 test/transpose_transpose.toy -emit=mlir -opt`:
 
 ```mlir
-func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
+toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
   toy.return %arg0 : tensor<*xf64>
 }
 ```
 
@@ -228,7 +228,7 @@ def main() {
 
 ```mlir
 module {
-  func @main() {
+  toy.func @main() {
     %0 = toy.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
     %1 = toy.reshape(%0 : tensor<2xf64>) to tensor<2x1xf64>
     %2 = toy.reshape(%1 : tensor<2x1xf64>) to tensor<2x1xf64>

@@ -244,7 +244,7 @@ We can try to run `toyc-ch3 test/Examples/Toy/Ch3/trivial_reshape.toy -emit=mlir
 
 ```mlir
 module {
-  func @main() {
+  toy.func @main() {
     %0 = toy.constant dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf64>
     toy.print %0 : tensor<2x1xf64>
     toy.return

@@ -77,6 +77,14 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
     return true;
   }
 
+  /// This hook checks if the given 'src' region can be inlined into the 'dest'
+  /// region. The regions here are the bodies of the callable functions. For
+  /// Toy, any function can be inlined, so we simply return true.
+  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+                       BlockAndValueMapping &valueMapping) const final {
+    return true;
+  }
+
   /// This hook is called when a terminator operation has been inlined. The only
   /// terminator that we have in the Toy dialect is the return
   /// operation(toy.return). We handle the return by replacing the values

@@ -101,7 +109,7 @@ main function) in the MLIR generator.
 
 ```c++
 /// Emit a new function and add it to the MLIR module.
-mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
   ...
   // If this function isn't main, then set the visibility to private.
   if (funcAST.getProto()->getName() != "main")

@@ -121,12 +129,12 @@ void ToyDialect::initialize() {
 ```
 
 Next, we need to provide a way for the inliner to know that `toy.generic_call`
-represents a call to a function. MLIR provides an
-[operation interface](../../Interfaces.md/#attributeoperationtype-interfaces) that can be used
-to mark an operation as being "call-like". Unlike dialect interfaces, operation
-interfaces provide a more refined granularity of information that is specific
-and core to a single operation. The interface that we will be adding here is the
-`CallOpInterface`.
+represents a call, and `toy.func` represents a function. MLIR provides
+[operation interfaces](../../Interfaces.md/#attributeoperationtype-interfaces) that can be used
+to mark an operation as being "call-like" or "callable-like". Unlike dialect interfaces,
+operation interfaces provide a more refined granularity of information that is specific
+and core to a single operation. The interfaces that we will be adding here are the
+`CallOpInterface` and `CallableOpInterface`.
 
 To add these interfaces we just need to include the definition into our operation
 specification file (`Ops.td`):
 
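Editorial note, not part of the commit: as a hedged illustration of what these interfaces buy us, once the ops implement them, generic code such as the inliner can discover call edges without any toy-specific knowledge:

```c++
#include "mlir/Interfaces/CallInterfaces.h"

// Walk an arbitrary operation and, if it is "call-like", query its callee.
void inspectCall(mlir::Operation *op) {
  if (auto call = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
    // For toy.generic_call this is the `callee` symbol reference attribute.
    mlir::CallInterfaceCallable callee = call.getCallableForCallee();
    (void)callee;
  }
}
```
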
@@ -138,6 +146,11 @@ include "mlir/Interfaces/CallInterfaces.td"
 and add it to the traits list of `GenericCallOp`:
 
 ```tablegen
+def FuncOp : Toy_Op<"func",
+    [DeclareOpInterfaceMethods<CallableOpInterface>]> {
+  ...
+}
+
 def GenericCallOp : Toy_Op<"generic_call",
     [DeclareOpInterfaceMethods<CallOpInterface>]> {
   ...

@@ -149,6 +162,15 @@ auto-declare all of the interface methods in the class declaration of
 GenericCallOp. This means that we just need to provide a definition:
 
 ```c++
+/// Returns the region on the function operation that is callable.
+Region *FuncOp::getCallableRegion() { return &getBody(); }
+
+/// Returns the results types that the callable region produces when
+/// executed.
+ArrayRef<Type> FuncOp::getCallableResults() { return getType().getResults(); }
+
+// ....
+
 /// Return the callee of the generic call operation, this is required by the
 /// call interface.
 CallInterfaceCallable GenericCallOp::getCallableForCallee() {

@@ -170,13 +192,13 @@ inliner pass to the pass manager for Toy:
 Now let's look at a working example:
 
 ```mlir
-func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+toy.func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
   %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
   %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
   %2 = toy.mul %0, %1 : tensor<*xf64>
   toy.return %2 : tensor<*xf64>
 }
-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
   %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -214,6 +236,7 @@ def CastOp : Toy_Op<"cast", [
 
   let arguments = (ins F64Tensor:$input);
   let results = (outs F64Tensor:$output);
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
 }
 ```
 
@@ -263,14 +286,14 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
 If we run the working example through the pipeline again, we get the expected:
 
 ```mlir
-func @main() {
-  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
-  %1 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
-  %2 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<*xf64>
-  %3 = "toy.cast"(%0) : (tensor<2x3xf64>) -> tensor<*xf64>
-  %4 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64>
-  %5 = "toy.transpose"(%3) : (tensor<*xf64>) -> tensor<*xf64>
-  %6 = "toy.mul"(%4, %5) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+toy.func @main() {
+  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
+  %1 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
+  %2 = toy.cast %1 : tensor<2x3xf64> to tensor<*xf64>
+  %3 = toy.cast %0 : tensor<2x3xf64> to tensor<*xf64>
+  %4 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64>
+  %5 = toy.transpose(%3 : tensor<*xf64>) to tensor<*xf64>
+  %6 = toy.mul %4, %5 : tensor<*xf64>
   toy.print %6 : tensor<*xf64>
   toy.return
 }

@@ -357,10 +380,10 @@ void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
 
 At this point, each of the necessary Toy operations provide a mechanism by which
 to infer their output shapes. The ShapeInferencePass will operate on functions:
-it will run on each Function in isolation. MLIR also supports general
+it will run on each function in isolation. MLIR also supports general
 [OperationPasses](../../PassManagement.md#operation-pass) that run on any
-isolated operation (i.e. other function-like operations), but here our module
-only contains functions, so there is no need to generalize to all operations.
+isolated operation, but here our module only contains functions, so there is no
+need to generalize to all operations.
 
 Implementing such a pass is done by creating a class inheriting from
 `mlir::OperationPass<FuncOp>` and overriding the `runOnOperation()` method.
 
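Editorial note, not from the commit: a skeleton of what such a pass class looks like, sketched here for orientation (the real pass, with its worklist algorithm, appears in the source hunks below; `toy/Dialect.h` is assumed to provide `mlir::toy::FuncOp`):

```c++
#include "mlir/Pass/Pass.h"
#include "toy/Dialect.h" // assumed: provides mlir::toy::FuncOp

// Minimal sketch of a shape inference pass anchored on toy.func; the
// worklist-driven inference loop is elided.
class ShapeInferencePassSketch
    : public mlir::PassWrapper<ShapeInferencePassSketch,
                               mlir::OperationPass<mlir::toy::FuncOp>> {
public:
  void runOnOperation() override {
    mlir::toy::FuncOp f = getOperation();
    // ... collect ops with dynamic result shapes and propagate types ...
    (void)f;
  }
};
```
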
@@ -421,10 +444,10 @@ We can then add our pass to the pass manager:
 If we rerun our original example, we now get the following:
 
 ```mlir
-func @main() {
-  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
-  %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
-  %2 = "toy.mul"(%1, %1) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
+toy.func @main() {
+  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
+  %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
+  %2 = toy.mul %1, %1 : tensor<3x2xf64>
   toy.print %2 : tensor<3x2xf64>
   toy.return
 }

@@ -172,8 +172,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our *illegal*
   // operations were not converted successfully.
-  mlir::FuncOp function = getOperation();
-  if (mlir::failed(mlir::applyPartialConversion(function, target, patterns)))
+  if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns)))
     signalPassFailure();
 }
 ```

@@ -232,7 +231,7 @@ def PrintOp : Toy_Op<"print"> {
 Let's take a concrete example:
 
 ```mlir
-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
   %3 = toy.mul %2, %2 : tensor<3x2xf64>

@@ -119,7 +119,7 @@ that only legal operations will remain after the conversion.
 Looking back at our current working example:
 
 ```mlir
-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
   %3 = toy.mul %2, %2 : tensor<3x2xf64>

@@ -327,7 +327,7 @@ Which generates the following:
 
 ```mlir
 module {
-  func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
+  toy.func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
     toy.return
   }
 }

@@ -405,7 +405,7 @@ and finally get a full MLIR module:
 
 ```mlir
 module {
-  func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> {
+  toy.func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64> {
     %0 = toy.struct_access %arg0[0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
     %1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
    %2 = toy.struct_access %arg0[1] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>

@@ -413,7 +413,7 @@ module {
     %4 = toy.mul %1, %3 : tensor<*xf64>
     toy.return %4 : tensor<*xf64>
   }
-  func @main() {
+  toy.func @main() {
     %0 = toy.struct_constant [
       dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>,
       dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>

@@ -434,7 +434,7 @@ After inlining, the MLIR module in the previous section looks something like:
 
 ```mlir
 module {
-  func @main() {
+  toy.func @main() {
     %0 = toy.struct_constant [
       dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>,
       dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>

@@ -500,7 +500,7 @@ changes to our pipeline.
 
 ```mlir
 module {
-  func @main() {
+  toy.func @main() {
     %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
     %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
     %2 = toy.mul %1, %1 : tensor<3x2xf64>

@@ -14,8 +14,10 @@
 #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
 #define MLIR_TUTORIAL_TOY_DIALECT_H_
 
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 /// Include the auto-generated header file containing the declaration of the toy

@@ -14,6 +14,8 @@
 #define TOY_OPS
 
 include "mlir/IR/OpBase.td"
+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // Provide a definition of the 'toy' dialect in the ODS framework so that we

@@ -106,6 +108,61 @@ def AddOp : Toy_Op<"add"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+def FuncOp : Toy_Op<"func", [
+    FunctionOpInterface, IsolatedFromAbove, Symbol
+  ]> {
+  let summary = "user defined function operation";
+  let description = [{
+    The "toy.func" operation represents a user defined function. These are
+    callable SSA-region operations that contain toy computations.
+
+    Example:
+
+    ```mlir
+    toy.func @main() {
+      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
+      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
+      toy.print %1 : tensor<2x2xf64>
+      toy.return
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttrOf<FunctionType>:$type
+  );
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
+  >];
+  let extraClassDeclaration = [{
+    /// Returns the type of this function.
+    /// FIXME: We should drive this via the ODS `type` param.
+    FunctionType getType() {
+      return getTypeAttr().getValue().cast<FunctionType>();
+    }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return type().getResults(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

@@ -15,6 +15,7 @@
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 
 using namespace mlir;

@@ -187,6 +188,39 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                         mlir::SymbolRefAttr::get(builder.getContext(), callee));
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   llvm::StringRef name, mlir::FunctionType type,
+                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  // FunctionOpInterface provides a convenient `build` method that will populate
+  // the state of our FuncOp, and create an entry block.
+  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
+}
+
+mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
+                                mlir::OperationState &result) {
+  // Dispatch to the FunctionOpInterface provided utility method that parses the
+  // function operation.
+  auto buildFuncType =
+      [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
+         llvm::ArrayRef<mlir::Type> results,
+         mlir::function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return mlir::function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(mlir::OpAsmPrinter &p) {
+  // Dispatch to the FunctionOpInterface provided utility method that prints the
+  // function operation.
+  mlir::function_interface_impl::printFunctionOp(p, *this,
+                                                 /*isVariadic=*/false);
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//

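Editorial note, not from the commit: a small usage sketch of the builder defined above. `builder` and `loc` are assumed to come from surrounding codegen context, as in the MLIRGen hunks that follow:

```c++
// Sketch: create an empty `toy.func @main() -> ()` and position `builder`
// at the start of its entry block.
static mlir::toy::FuncOp emitEmptyMain(mlir::OpBuilder &builder,
                                       mlir::Location loc) {
  auto funcType = builder.getFunctionType(/*inputs=*/llvm::None,
                                          /*results=*/llvm::None);
  auto func = builder.create<mlir::toy::FuncOp>(loc, "main", funcType);
  // buildWithEntryBlock (used by the builder above) already created the
  // entry block, so we can insert straight into it.
  builder.setInsertionPointToStart(&func.front());
  return func;
}
```
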
@@ -58,12 +58,8 @@ public:
     // add them to the module.
     theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
 
-    for (FunctionAST &f : moduleAST) {
-      auto func = mlirGen(f);
-      if (!func)
-        return nullptr;
-      theModule.push_back(func);
-    }
+    for (FunctionAST &f : moduleAST)
+      mlirGen(f);
 
     // Verify the module after we have finished constructing it, this will check
     // the structural properties of the IR and invoke any specific verifiers we

@@ -108,7 +104,7 @@ private:
 
   /// Create the prototype for an MLIR function with as many arguments as the
   /// provided Toy AST prototype.
-  mlir::FuncOp mlirGen(PrototypeAST &proto) {
+  mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
     auto location = loc(proto.loc());
 
     // This is a generic function, the return type will be inferred later.

@@ -116,23 +112,23 @@ private:
     llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
                                               getType(VarType{}));
     auto funcType = builder.getFunctionType(argTypes, llvm::None);
-    return mlir::FuncOp::create(location, proto.getName(), funcType);
+    return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
+                                             funcType);
   }
 
   /// Emit a new function and add it to the MLIR module.
-  mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+  mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
     // Create a scope in the symbol table to hold variable declarations.
     ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
 
     // Create an MLIR function for the given prototype.
-    mlir::FuncOp function(mlirGen(*funcAST.getProto()));
+    builder.setInsertionPointToEnd(theModule.getBody());
+    mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
     if (!function)
       return nullptr;
 
     // Let's start the body of the function now!
     // In MLIR the entry block of the function is special: it must have the same
     // argument list as the function itself.
-    auto &entryBlock = *function.addEntryBlock();
+    mlir::Block &entryBlock = function.front();
     auto protoArgs = funcAST.getProto()->getArgs();
 
     // Declare all the function arguments in the symbol table.

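Editorial note: the hunk above ends just before the argument handling. For completeness, the body continues roughly as follows, a sketch based on the tutorial's `declare` symbol-table helper, which is not shown in this diff:

```c++
// Map each Toy AST argument name onto the corresponding entry-block argument.
for (const auto nameValue :
     llvm::zip(protoArgs, entryBlock.getArguments())) {
  if (mlir::failed(declare(std::get<0>(nameValue)->getName(),
                           std::get<1>(nameValue))))
    return nullptr;
}

// All expression codegen below inserts into the entry block.
builder.setInsertionPointToStart(&entryBlock);
```
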
@@ -14,8 +14,10 @@
 #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
 #define MLIR_TUTORIAL_TOY_DIALECT_H_
 
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 /// Include the auto-generated header file containing the declaration of the toy

@@ -13,6 +13,8 @@
 #ifndef TOY_OPS
 #define TOY_OPS
 
+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // Provide a definition of the 'toy' dialect in the ODS framework so that we

@@ -105,6 +107,61 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+def FuncOp : Toy_Op<"func", [
+    FunctionOpInterface, IsolatedFromAbove, Symbol
+  ]> {
+  let summary = "user defined function operation";
+  let description = [{
+    The "toy.func" operation represents a user defined function. These are
+    callable SSA-region operations that contain toy computations.
+
+    Example:
+
+    ```mlir
+    toy.func @main() {
+      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
+      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
+      toy.print %1 : tensor<2x2xf64>
+      toy.return
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttrOf<FunctionType>:$type
+  );
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
+  >];
+  let extraClassDeclaration = [{
+    /// Returns the type of this function.
+    /// FIXME: We should drive this via the ODS `type` param.
+    FunctionType getType() {
+      return getTypeAttr().getValue().cast<FunctionType>();
+    }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return type().getResults(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

@@ -15,6 +15,7 @@
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 
 using namespace mlir;

@@ -174,6 +175,39 @@ mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,
 
 void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   llvm::StringRef name, mlir::FunctionType type,
+                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  // FunctionOpInterface provides a convenient `build` method that will populate
+  // the state of our FuncOp, and create an entry block.
+  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
+}
+
+mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
+                                mlir::OperationState &result) {
+  // Dispatch to the FunctionOpInterface provided utility method that parses the
+  // function operation.
+  auto buildFuncType =
+      [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
+         llvm::ArrayRef<mlir::Type> results,
+         mlir::function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return mlir::function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(mlir::OpAsmPrinter &p) {
+  // Dispatch to the FunctionOpInterface provided utility method that prints the
+  // function operation.
+  mlir::function_interface_impl::printFunctionOp(p, *this,
+                                                 /*isVariadic=*/false);
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

@@ -58,12 +58,8 @@ public:
     // add them to the module.
     theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
 
-    for (FunctionAST &f : moduleAST) {
-      auto func = mlirGen(f);
-      if (!func)
-        return nullptr;
-      theModule.push_back(func);
-    }
+    for (FunctionAST &f : moduleAST)
+      mlirGen(f);
 
     // Verify the module after we have finished constructing it, this will check
     // the structural properties of the IR and invoke any specific verifiers we

@@ -108,7 +104,7 @@ private:
 
   /// Create the prototype for an MLIR function with as many arguments as the
   /// provided Toy AST prototype.
-  mlir::FuncOp mlirGen(PrototypeAST &proto) {
+  mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
     auto location = loc(proto.loc());
 
     // This is a generic function, the return type will be inferred later.

@@ -116,23 +112,23 @@ private:
     llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
                                               getType(VarType{}));
     auto funcType = builder.getFunctionType(argTypes, llvm::None);
-    return mlir::FuncOp::create(location, proto.getName(), funcType);
+    return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
+                                             funcType);
   }
 
   /// Emit a new function and add it to the MLIR module.
-  mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+  mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
     // Create a scope in the symbol table to hold variable declarations.
     ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
 
     // Create an MLIR function for the given prototype.
-    mlir::FuncOp function(mlirGen(*funcAST.getProto()));
+    builder.setInsertionPointToEnd(theModule.getBody());
+    mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
     if (!function)
       return nullptr;
 
     // Let's start the body of the function now!
     // In MLIR the entry block of the function is special: it must have the same
     // argument list as the function itself.
-    auto &entryBlock = *function.addEntryBlock();
+    mlir::Block &entryBlock = function.front();
     auto protoArgs = funcAST.getProto()->getArgs();
 
     // Declare all the function arguments in the symbol table.

@@ -118,7 +118,7 @@ int dumpMLIR() {
   applyPassManagerCLOptions(pm);
 
   // Add a run of the canonicalizer to optimize the mlir module.
-  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+  pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
   if (mlir::failed(pm.run(*module)))
     return 4;
 }

@@ -14,9 +14,11 @@
 #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
 #define MLIR_TUTORIAL_TOY_DIALECT_H_
 
-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "toy/ShapeInferenceInterface.h"

@@ -13,6 +13,8 @@
 #ifndef TOY_OPS
 #define TOY_OPS
 
+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"

@@ -134,6 +136,62 @@ def CastOp : Toy_Op<"cast", [
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+def FuncOp : Toy_Op<"func", [
+    DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
+    IsolatedFromAbove, Symbol
+  ]> {
+  let summary = "user defined function operation";
+  let description = [{
+    The "toy.func" operation represents a user defined function. These are
+    callable SSA-region operations that contain toy computations.
+
+    Example:
+
+    ```mlir
+    toy.func @main() {
+      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
+      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
+      toy.print %1 : tensor<2x2xf64>
+      toy.return
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttrOf<FunctionType>:$type
+  );
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
+  >];
+  let extraClassDeclaration = [{
+    /// Returns the type of this function.
+    /// FIXME: We should drive this via the ODS `type` param.
+    FunctionType getType() {
+      return getTypeAttr().getValue().cast<FunctionType>();
+    }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return type().getResults(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

@@ -15,6 +15,7 @@
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"

@@ -48,6 +49,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
     return true;
   }
 
+  // All functions within toy can be inlined.
+  bool isLegalToInline(Region *, Region *, bool,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+
   //===--------------------------------------------------------------------===//
   // Transformation Hooks
   //===--------------------------------------------------------------------===//

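Editorial note, not from this hunk: the hook above only takes effect once the interface is registered with the dialect. As shown in the Chapter 4 doc hunk earlier (`void ToyDialect::initialize()`), that registration is roughly:

```c++
void ToyDialect::initialize() {
  // ... generated op registration elided ...
  addInterfaces<ToyInlinerInterface>();
}
```
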
@@ -257,6 +264,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return !input.hasRank() || !output.hasRank() || input == output;
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   llvm::StringRef name, mlir::FunctionType type,
+                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  // FunctionOpInterface provides a convenient `build` method that will populate
+  // the state of our FuncOp, and create an entry block.
+  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
+}
+
+mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
+                                mlir::OperationState &result) {
+  // Dispatch to the FunctionOpInterface provided utility method that parses the
+  // function operation.
+  auto buildFuncType =
+      [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
+         llvm::ArrayRef<mlir::Type> results,
+         mlir::function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return mlir::function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(mlir::OpAsmPrinter &p) {
+  // Dispatch to the FunctionOpInterface provided utility method that prints the
+  // function operation.
+  mlir::function_interface_impl::printFunctionOp(p, *this,
+                                                 /*isVariadic=*/false);
+}
+
+/// Returns the region on the function operation that is callable.
+mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }
+
+/// Returns the results types that the callable region produces when
+/// executed.
+llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
+  return getType().getResults();
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

@@ -58,12 +58,8 @@ public:
     // add them to the module.
     theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
 
-    for (FunctionAST &f : moduleAST) {
-      auto func = mlirGen(f);
-      if (!func)
-        return nullptr;
-      theModule.push_back(func);
-    }
+    for (FunctionAST &f : moduleAST)
+      mlirGen(f);
 
     // Verify the module after we have finished constructing it, this will check
     // the structural properties of the IR and invoke any specific verifiers we

@@ -108,7 +104,7 @@ private:
 
   /// Create the prototype for an MLIR function with as many arguments as the
   /// provided Toy AST prototype.
-  mlir::FuncOp mlirGen(PrototypeAST &proto) {
+  mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
     auto location = loc(proto.loc());
 
     // This is a generic function, the return type will be inferred later.

@@ -116,23 +112,23 @@ private:
     llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
                                               getType(VarType{}));
     auto funcType = builder.getFunctionType(argTypes, llvm::None);
-    return mlir::FuncOp::create(location, proto.getName(), funcType);
+    return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
+                                             funcType);
   }
 
   /// Emit a new function and add it to the MLIR module.
-  mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+  mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
     // Create a scope in the symbol table to hold variable declarations.
     ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
 
     // Create an MLIR function for the given prototype.
-    mlir::FuncOp function(mlirGen(*funcAST.getProto()));
+    builder.setInsertionPointToEnd(theModule.getBody());
+    mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
     if (!function)
       return nullptr;
 
     // Let's start the body of the function now!
     // In MLIR the entry block of the function is special: it must have the same
     // argument list as the function itself.
-    auto &entryBlock = *function.addEntryBlock();
+    mlir::Block &entryBlock = function.front();
     auto protoArgs = funcAST.getProto()->getArgs();
 
     // Declare all the function arguments in the symbol table.

@@ -45,7 +45,7 @@ namespace {
 /// 3) If the worklist is empty, the algorithm succeeded.
 ///
 class ShapeInferencePass
-    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
+    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
 public:
   void runOnOperation() override {
     auto f = getOperation();

@@ -123,7 +123,7 @@ int dumpMLIR() {
 
   // Now that there is only one function, we can infer the shapes of each of
   // the operations.
-  mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
+  mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
   optPM.addPass(mlir::toy::createShapeInferencePass());
   optPM.addPass(mlir::createCanonicalizerPass());
   optPM.addPass(mlir::createCSEPass());

@@ -14,9 +14,11 @@
 #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
 #define MLIR_TUTORIAL_TOY_DIALECT_H_
 
-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "toy/ShapeInferenceInterface.h"

@@ -13,6 +13,8 @@
 #ifndef TOY_OPS
 #define TOY_OPS
 
+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"

@@ -134,6 +136,62 @@ def CastOp : Toy_Op<"cast", [
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+def FuncOp : Toy_Op<"func", [
+    DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
+    IsolatedFromAbove, Symbol
+  ]> {
+  let summary = "user defined function operation";
+  let description = [{
+    The "toy.func" operation represents a user defined function. These are
+    callable SSA-region operations that contain toy computations.
+
+    Example:
+
+    ```mlir
+    toy.func @main() {
+      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
+      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
+      toy.print %1 : tensor<2x2xf64>
+      toy.return
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttrOf<FunctionType>:$type
+  );
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
+  >];
+  let extraClassDeclaration = [{
+    /// Returns the type of this function.
+    /// FIXME: We should drive this via the ODS `type` param.
+    FunctionType getType() {
+      return getTypeAttr().getValue().cast<FunctionType>();
+    }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return type().getResults(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

@@ -15,6 +15,7 @@
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"

@@ -48,6 +49,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
     return true;
   }
 
+  // All functions within toy can be inlined.
+  bool isLegalToInline(Region *, Region *, bool,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+
   //===--------------------------------------------------------------------===//
   // Transformation Hooks
   //===--------------------------------------------------------------------===//

@@ -257,6 +264,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return !input.hasRank() || !output.hasRank() || input == output;
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   llvm::StringRef name, mlir::FunctionType type,
+                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  // FunctionOpInterface provides a convenient `build` method that will populate
+  // the state of our FuncOp, and create an entry block.
+  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
+}
+
+mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
+                                mlir::OperationState &result) {
+  // Dispatch to the FunctionOpInterface provided utility method that parses the
+  // function operation.
+  auto buildFuncType =
+      [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
+         llvm::ArrayRef<mlir::Type> results,
+         mlir::function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return mlir::function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(mlir::OpAsmPrinter &p) {
+  // Dispatch to the FunctionOpInterface provided utility method that prints the
+  // function operation.
+  mlir::function_interface_impl::printFunctionOp(p, *this,
+                                                 /*isVariadic=*/false);
+}
+
+/// Returns the region on the function operation that is callable.
+mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }
+
+/// Returns the results types that the callable region produces when
+/// executed.
+llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
+  return getType().getResults();
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

@@ -12,6 +12,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinDialect.h"
 #include "toy/Dialect.h"
 #include "toy/Passes.h"

@@ -197,6 +198,37 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Func operations
+//===----------------------------------------------------------------------===//
+
+struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
+  using OpConversionPattern<toy::FuncOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    // We only lower the main function as we expect that all other functions
+    // have been inlined.
+    if (op.getName() != "main")
+      return failure();
+
+    // Verify that the given main has no inputs and results.
+    if (op.getNumArguments() || op.getType().getNumResults()) {
+      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
+        diag << "expected 'main' to have 0 inputs and 0 results";
+      });
+    }
+
+    // Create a new non-toy function, with the same region.
+    auto func =
+        rewriter.create<mlir::FuncOp>(op.getLoc(), op.getName(), op.getType());
+    rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ToyToAffine RewritePatterns: Print operations
 //===----------------------------------------------------------------------===//

@@ -277,7 +309,7 @@ struct TransposeOpLowering : public ConversionPattern {
 /// rest of the code in the Toy dialect.
 namespace {
 struct ToyToAffineLoweringPass
-    : public PassWrapper<ToyToAffineLoweringPass, OperationPass<FuncOp>> {
+    : public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect, func::FuncDialect, memref::MemRefDialect>();
   }

@@ -286,19 +318,6 @@ struct ToyToAffineLoweringPass
 } // namespace
 
 void ToyToAffineLoweringPass::runOnOperation() {
-  FuncOp function = getOperation();
-
-  // We only lower the main function as we expect that all other functions have
-  // been inlined.
-  if (function.getName() != "main")
-    return;
-
-  // Verify that the given main has no inputs and results.
-  if (function.getNumArguments() || function.getType().getNumResults()) {
-    function.emitError("expected 'main' to have 0 inputs and 0 results");
-    return signalPassFailure();
-  }
-
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering.
   ConversionTarget target(getContext());

@@ -306,8 +325,9 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // We define the specific operations, or dialects, that are legal targets for
   // this lowering. In our case, we are lowering to a combination of the
   // `Affine`, `Arithmetic`, `Func`, and `MemRef` dialects.
-  target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
-                         func::FuncDialect, memref::MemRefDialect>();
+  target
+      .addLegalDialect<AffineDialect, BuiltinDialect, arith::ArithmeticDialect,
+                       func::FuncDialect, memref::MemRefDialect>();
 
   // We also define the Toy dialect as Illegal so that the conversion will fail
   // if any of these operations are *not* converted. Given that we actually want

@@ -324,7 +344,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // Now that the conversion target has been defined, we just need to provide
   // the set of patterns that will lower the Toy operations.
   RewritePatternSet patterns(&getContext());
-  patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
+  patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
               PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
       &getContext());

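Editorial note, not from the commit: taken together with the hunks above, the lowering pass now runs on the whole module. A hedged sketch of the resulting pass body, using only names that appear elsewhere in this diff and eliding the dynamic-legality details:

```c++
void ToyToAffineLoweringPass::runOnOperation() {
  // Define the final target for this lowering; the builtin dialect is legal
  // so the surrounding module and lowered functions are accepted.
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::AffineDialect, mlir::BuiltinDialect,
                         mlir::arith::ArithmeticDialect,
                         mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
  target.addIllegalDialect<mlir::toy::ToyDialect>();

  mlir::RewritePatternSet patterns(&getContext());
  patterns.add<FuncOpLowering>(&getContext()); // plus the other lowerings above

  if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                std::move(patterns))))
    signalPassFailure();
}
```
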
@@ -58,12 +58,8 @@ public:
     // add them to the module.
     theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
 
-    for (FunctionAST &f : moduleAST) {
-      auto func = mlirGen(f);
-      if (!func)
-        return nullptr;
-      theModule.push_back(func);
-    }
+    for (FunctionAST &f : moduleAST)
+      mlirGen(f);
 
     // Verify the module after we have finished constructing it, this will check
     // the structural properties of the IR and invoke any specific verifiers we

@@ -108,7 +104,7 @@ private:
 
   /// Create the prototype for an MLIR function with as many arguments as the
   /// provided Toy AST prototype.
-  mlir::FuncOp mlirGen(PrototypeAST &proto) {
+  mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
     auto location = loc(proto.loc());
 
     // This is a generic function, the return type will be inferred later.

@@ -116,23 +112,23 @@ private:
     llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
                                               getType(VarType{}));
    auto funcType = builder.getFunctionType(argTypes, llvm::None);
-    return mlir::FuncOp::create(location, proto.getName(), funcType);
+    return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
+                                             funcType);
   }
 
   /// Emit a new function and add it to the MLIR module.
-  mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+  mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
     // Create a scope in the symbol table to hold variable declarations.
     ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
 
     // Create an MLIR function for the given prototype.
-    mlir::FuncOp function(mlirGen(*funcAST.getProto()));
+    builder.setInsertionPointToEnd(theModule.getBody());
+    mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
     if (!function)
       return nullptr;
 
     // Let's start the body of the function now!
     // In MLIR the entry block of the function is special: it must have the same
     // argument list as the function itself.
-    auto &entryBlock = *function.addEntryBlock();
+    mlir::Block &entryBlock = function.front();
     auto protoArgs = funcAST.getProto()->getArgs();
 
     // Declare all the function arguments in the symbol table.

@ -45,7 +45,7 @@ namespace {
|
|||
/// 3) If the worklist is empty, the algorithm succeeded.
|
||||
///
|
||||
class ShapeInferencePass
|
||||
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
|
||||
: public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
auto f = getOperation();
|
||||
|
|
|
@ -130,17 +130,18 @@ int dumpMLIR() {
|
|||
|
||||
// Now that there is only one function, we can infer the shapes of each of
|
||||
// the operations.
|
||||
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
|
||||
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
|
||||
optPM.addPass(mlir::toy::createShapeInferencePass());
|
||||
optPM.addPass(mlir::createCanonicalizerPass());
|
||||
optPM.addPass(mlir::createCSEPass());
|
||||
}
|
||||
|
||||
if (isLoweringToAffine) {
|
||||
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
|
||||
// Partially lower the toy dialect.
|
||||
pm.addPass(mlir::toy::createLowerToAffinePass());
|
||||
|
||||
// Partially lower the toy dialect with a few cleanups afterwards.
|
||||
optPM.addPass(mlir::toy::createLowerToAffinePass());
|
||||
// Add a few cleanups post lowering.
|
||||
mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
|
||||
optPM.addPass(mlir::createCanonicalizerPass());
|
||||
optPM.addPass(mlir::createCSEPass());
|
||||
|
||||
|
|
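The reshuffle above is a direct consequence of the new op: shape inference still runs per function, so it nests on `mlir::toy::FuncOp`, while the affine lowering must now run at module scope because it erases `toy.func` and creates a builtin function in its place. A condensed sketch of the assembled pipeline (illustrative only; `context` and `module` are assumed to be in scope):

```cpp
mlir::PassManager pm(&context);

// Module-level: inline everything into main, then infer shapes per toy.func.
pm.addPass(mlir::createInlinerPass());
mlir::OpPassManager &toyFuncPM = pm.nest<mlir::toy::FuncOp>();
toyFuncPM.addPass(mlir::toy::createShapeInferencePass());
toyFuncPM.addPass(mlir::createCanonicalizerPass());

// Module-level lowering (it replaces toy.func), then per-function cleanups
// on the builtin functions the lowering produced.
pm.addPass(mlir::toy::createLowerToAffinePass());
mlir::OpPassManager &funcPM = pm.nest<mlir::FuncOp>();
funcPM.addPass(mlir::createCanonicalizerPass());
funcPM.addPass(mlir::createCSEPass());

if (mlir::failed(pm.run(module)))
  llvm::errs() << "pass pipeline failed\n";
```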
@@ -14,9 +14,11 @@
 #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
 #define MLIR_TUTORIAL_TOY_DIALECT_H_

-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "toy/ShapeInferenceInterface.h"

@@ -13,6 +13,8 @@
 #ifndef TOY_OPS
 #define TOY_OPS

+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"

@@ -134,6 +136,62 @@ def CastOp : Toy_Op<"cast", [
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
 }

+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+def FuncOp : Toy_Op<"func", [
+    DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
+    IsolatedFromAbove, Symbol
+  ]> {
+  let summary = "user defined function operation";
+  let description = [{
+    The "toy.func" operation represents a user defined function. These are
+    callable SSA-region operations that contain toy computations.
+
+    Example:
+
+    ```mlir
+    toy.func @main() {
+      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
+      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
+      toy.print %1 : tensor<2x2xf64>
+      toy.return
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttrOf<FunctionType>:$type
+  );
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
+  >];
+  let extraClassDeclaration = [{
+    /// Returns the type of this function.
+    /// FIXME: We should drive this via the ODS `type` param.
+    FunctionType getType() {
+      return getTypeAttr().getValue().cast<FunctionType>();
+    }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return type().getResults(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//
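Once TableGen processes this definition, `mlir::toy::FuncOp` behaves like any other generated op class. A small usage sketch follows; this is my own illustration, not part of the patch, and the helper name is hypothetical:

```cpp
#include <cassert>

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "toy/Dialect.h"

// Hypothetical helper, for illustration only: create an empty `toy.func`
// named "main" at the end of a module and poke at the generated API.
static mlir::toy::FuncOp makeEmptyMain(mlir::OpBuilder &builder,
                                       mlir::ModuleOp module) {
  builder.setInsertionPointToEnd(module.getBody());

  // The ODS-declared builder also creates the entry block (see the
  // FuncOp::build implementation further down, which calls
  // buildWithEntryBlock).
  auto funcType = builder.getFunctionType(/*inputs=*/llvm::None,
                                          /*results=*/llvm::None);
  auto func = builder.create<mlir::toy::FuncOp>(builder.getUnknownLoc(),
                                                "main", funcType);

  // `getName()` comes from the Symbol trait; `getType()` is declared in the
  // extraClassDeclaration block above.
  assert(func.getName() == "main" && func.getType().getNumInputs() == 0);
  return func;
}
```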
@@ -15,6 +15,7 @@

 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"

@@ -48,6 +49,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
     return true;
   }

+  // All functions within toy can be inlined.
+  bool isLegalToInline(Region *, Region *, bool,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+
   //===--------------------------------------------------------------------===//
   // Transformation Hooks
   //===--------------------------------------------------------------------===//
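This new overload is the region-level query: without it, the inliner would conservatively refuse to splice a `toy.func` body into its caller. Exercising it only requires scheduling the standard inliner pass; a rough sketch (illustrative; `context` and `module` are assumed):

```cpp
// The inliner pass walks call sites such as toy.generic_call and consults
// the ToyInlinerInterface hooks above before inlining a callee's region.
mlir::PassManager pm(&context);
pm.addPass(mlir::createInlinerPass());
if (mlir::failed(pm.run(module)))
  llvm::errs() << "inlining failed\n";
```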
@@ -257,6 +264,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return !input.hasRank() || !output.hasRank() || input == output;
 }

+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   llvm::StringRef name, mlir::FunctionType type,
+                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  // FunctionOpInterface provides a convenient `build` method that will populate
+  // the state of our FuncOp, and create an entry block.
+  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
+}
+
+mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
+                                mlir::OperationState &result) {
+  // Dispatch to the FunctionOpInterface provided utility method that parses the
+  // function operation.
+  auto buildFuncType =
+      [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
+         llvm::ArrayRef<mlir::Type> results,
+         mlir::function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return mlir::function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(mlir::OpAsmPrinter &p) {
+  // Dispatch to the FunctionOpInterface provided utility method that prints the
+  // function operation.
+  mlir::function_interface_impl::printFunctionOp(p, *this,
+                                                 /*isVariadic=*/false);
+}
+
+/// Returns the region on the function operation that is callable.
+mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }
+
+/// Returns the results types that the callable region produces when
+/// executed.
+llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
+  return getType().getResults();
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//
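Because the op sets `hasCustomAssemblyFormat = 1`, these hooks define the readable `toy.func @name(...)` syntax. A quick way to sanity-check them is to round-trip a snippet through the parser; a rough sketch (the parser header path is era-dependent, so treat it as an assumption):

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser.h" // later trees: "mlir/Parser/Parser.h"
#include "llvm/Support/raw_ostream.h"
#include "toy/Dialect.h"

// Illustrative check, not from the patch: if parse/print are wired correctly,
// the custom form below parses and prints back in the same custom syntax.
static bool roundTripsToyFunc() {
  mlir::MLIRContext context;
  context.getOrLoadDialect<mlir::toy::ToyDialect>();

  const char *ir = R"mlir(
    toy.func @main() {
      toy.return
    }
  )mlir";
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString(ir, &context);
  if (!module)
    return false;
  module->print(llvm::outs()); // uses FuncOp::print for the toy.func
  return true;
}
```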
@@ -12,6 +12,7 @@
 //
 //===----------------------------------------------------------------------===//

+#include "mlir/IR/BuiltinDialect.h"
 #include "toy/Dialect.h"
 #include "toy/Passes.h"

@@ -197,6 +198,37 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
   }
 };

+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Func operations
+//===----------------------------------------------------------------------===//
+
+struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
+  using OpConversionPattern<toy::FuncOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    // We only lower the main function as we expect that all other functions
+    // have been inlined.
+    if (op.getName() != "main")
+      return failure();
+
+    // Verify that the given main has no inputs and results.
+    if (op.getNumArguments() || op.getType().getNumResults()) {
+      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
+        diag << "expected 'main' to have 0 inputs and 0 results";
+      });
+    }
+
+    // Create a new non-toy function, with the same region.
+    auto func =
+        rewriter.create<mlir::FuncOp>(op.getLoc(), op.getName(), op.getType());
+    rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ToyToAffine RewritePatterns: Print operations
 //===----------------------------------------------------------------------===//
@@ -277,7 +309,7 @@ struct TransposeOpLowering : public ConversionPattern {
 /// rest of the code in the Toy dialect.
 namespace {
 struct ToyToAffineLoweringPass
-    : public PassWrapper<ToyToAffineLoweringPass, OperationPass<FuncOp>> {
+    : public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect, func::FuncDialect, memref::MemRefDialect>();
   }

@@ -286,19 +318,6 @@ struct ToyToAffineLoweringPass
 } // namespace

 void ToyToAffineLoweringPass::runOnOperation() {
-  auto function = getOperation();
-
-  // We only lower the main function as we expect that all other functions have
-  // been inlined.
-  if (function.getName() != "main")
-    return;
-
-  // Verify that the given main has no inputs and results.
-  if (function.getNumArguments() || function.getType().getNumResults()) {
-    function.emitError("expected 'main' to have 0 inputs and 0 results");
-    return signalPassFailure();
-  }
-
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering.
   ConversionTarget target(getContext());

@@ -306,8 +325,9 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // We define the specific operations, or dialects, that are legal targets for
   // this lowering. In our case, we are lowering to a combination of the
   // `Affine`, `Arithmetic`, `Func`, and `MemRef` dialects.
-  target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
-                         func::FuncDialect, memref::MemRefDialect>();
+  target
+      .addLegalDialect<AffineDialect, BuiltinDialect, arith::ArithmeticDialect,
+                       func::FuncDialect, memref::MemRefDialect>();

   // We also define the Toy dialect as Illegal so that the conversion will fail
   // if any of these operations are *not* converted. Given that we actually want

@@ -324,7 +344,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // Now that the conversion target has been defined, we just need to provide
   // the set of patterns that will lower the Toy operations.
   RewritePatternSet patterns(&getContext());
-  patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
+  patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
                PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
       &getContext());
@@ -58,12 +58,8 @@ public:
   // add them to the module.
   theModule = mlir::ModuleOp::create(builder.getUnknownLoc());

-  for (FunctionAST &f : moduleAST) {
-    auto func = mlirGen(f);
-    if (!func)
-      return nullptr;
-    theModule.push_back(func);
-  }
+  for (FunctionAST &f : moduleAST)
+    mlirGen(f);

   // Verify the module after we have finished constructing it, this will check
   // the structural properties of the IR and invoke any specific verifiers we

@@ -108,7 +104,7 @@ private:

   /// Create the prototype for an MLIR function with as many arguments as the
   /// provided Toy AST prototype.
-  mlir::FuncOp mlirGen(PrototypeAST &proto) {
+  mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
     auto location = loc(proto.loc());

     // This is a generic function, the return type will be inferred later.

@@ -116,23 +112,23 @@ private:
     llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
                                               getType(VarType{}));
     auto funcType = builder.getFunctionType(argTypes, llvm::None);
-    return mlir::FuncOp::create(location, proto.getName(), funcType);
+    return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
+                                             funcType);
   }

   /// Emit a new function and add it to the MLIR module.
-  mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+  mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
     // Create a scope in the symbol table to hold variable declarations.
     ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);

     // Create an MLIR function for the given prototype.
-    mlir::FuncOp function(mlirGen(*funcAST.getProto()));
+    builder.setInsertionPointToEnd(theModule.getBody());
+    mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
     if (!function)
       return nullptr;

     // Let's start the body of the function now!
     // In MLIR the entry block of the function is special: it must have the same
     // argument list as the function itself.
-    auto &entryBlock = *function.addEntryBlock();
+    mlir::Block &entryBlock = function.front();
     auto protoArgs = funcAST.getProto()->getArgs();

     // Declare all the function arguments in the symbol table.

@@ -45,7 +45,7 @@ namespace {
 /// 3) If the worklist is empty, the algorithm succeeded.
 ///
 class ShapeInferencePass
-    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
+    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
 public:
   void runOnOperation() override {
     auto f = getOperation();

@@ -146,17 +146,18 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
     // Now that there is only one function, we can infer the shapes of each of
     // the operations.
-    mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
+    mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
     optPM.addPass(mlir::toy::createShapeInferencePass());
     optPM.addPass(mlir::createCanonicalizerPass());
     optPM.addPass(mlir::createCSEPass());
   }

   if (isLoweringToAffine) {
-    mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
-
-    // Partially lower the toy dialect with a few cleanups afterwards.
-    optPM.addPass(mlir::toy::createLowerToAffinePass());
+    // Partially lower the toy dialect.
+    pm.addPass(mlir::toy::createLowerToAffinePass());
+
+    // Add a few cleanups post lowering.
+    mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
     optPM.addPass(mlir::createCanonicalizerPass());
     optPM.addPass(mlir::createCSEPass());
@@ -14,9 +14,11 @@
 #ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
 #define MLIR_TUTORIAL_TOY_DIALECT_H_

-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "toy/ShapeInferenceInterface.h"

@@ -13,6 +13,8 @@
 #ifndef TOY_OPS
 #define TOY_OPS

+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"

@@ -153,6 +155,62 @@ def CastOp : Toy_Op<"cast", [
   let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
 }

+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+def FuncOp : Toy_Op<"func", [
+    DeclareOpInterfaceMethods<CallableOpInterface>, FunctionOpInterface,
+    IsolatedFromAbove, Symbol
+  ]> {
+  let summary = "user defined function operation";
+  let description = [{
+    The "toy.func" operation represents a user defined function. These are
+    callable SSA-region operations that contain toy computations.
+
+    Example:
+
+    ```mlir
+    toy.func @main() {
+      %0 = toy.constant dense<5.500000e+00> : tensor<f64>
+      %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
+      toy.print %1 : tensor<2x2xf64>
+      toy.return
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttrOf<FunctionType>:$type
+  );
+  let regions = (region AnyRegion:$body);
+
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
+  >];
+  let extraClassDeclaration = [{
+    /// Returns the type of this function.
+    /// FIXME: We should drive this via the ODS `type` param.
+    FunctionType getType() {
+      return getTypeAttr().getValue().cast<FunctionType>();
+    }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return type().getResults(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//
@@ -16,6 +16,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"

@@ -49,6 +50,12 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
     return true;
   }

+  // All functions within toy can be inlined.
+  bool isLegalToInline(Region *, Region *, bool,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+
   //===--------------------------------------------------------------------===//
   // Transformation Hooks
   //===--------------------------------------------------------------------===//

@@ -284,6 +291,48 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return !input.hasRank() || !output.hasRank() || input == output;
 }

+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                   llvm::StringRef name, mlir::FunctionType type,
+                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  // FunctionOpInterface provides a convenient `build` method that will populate
+  // the state of our FuncOp, and create an entry block.
+  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
+}
+
+mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser,
+                                mlir::OperationState &result) {
+  // Dispatch to the FunctionOpInterface provided utility method that parses the
+  // function operation.
+  auto buildFuncType =
+      [](mlir::Builder &builder, llvm::ArrayRef<mlir::Type> argTypes,
+         llvm::ArrayRef<mlir::Type> results,
+         mlir::function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return mlir::function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(mlir::OpAsmPrinter &p) {
+  // Dispatch to the FunctionOpInterface provided utility method that prints the
+  // function operation.
+  mlir::function_interface_impl::printFunctionOp(p, *this,
+                                                 /*isVariadic=*/false);
+}
+
+/// Returns the region on the function operation that is callable.
+mlir::Region *FuncOp::getCallableRegion() { return &getBody(); }
+
+/// Returns the results types that the callable region produces when
+/// executed.
+llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
+  return getType().getResults();
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

@@ -12,6 +12,7 @@
 //
 //===----------------------------------------------------------------------===//

+#include "mlir/IR/BuiltinDialect.h"
 #include "toy/Dialect.h"
 #include "toy/Passes.h"

@@ -197,6 +198,37 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
   }
 };

+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Func operations
+//===----------------------------------------------------------------------===//
+
+struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
+  using OpConversionPattern<toy::FuncOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    // We only lower the main function as we expect that all other functions
+    // have been inlined.
+    if (op.getName() != "main")
+      return failure();
+
+    // Verify that the given main has no inputs and results.
+    if (op.getNumArguments() || op.getType().getNumResults()) {
+      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
+        diag << "expected 'main' to have 0 inputs and 0 results";
+      });
+    }
+
+    // Create a new non-toy function, with the same region.
+    auto func =
+        rewriter.create<mlir::FuncOp>(op.getLoc(), op.getName(), op.getType());
+    rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ToyToAffine RewritePatterns: Print operations
 //===----------------------------------------------------------------------===//
@@ -277,7 +309,7 @@ struct TransposeOpLowering : public ConversionPattern {
 /// rest of the code in the Toy dialect.
 namespace {
 struct ToyToAffineLoweringPass
-    : public PassWrapper<ToyToAffineLoweringPass, OperationPass<FuncOp>> {
+    : public PassWrapper<ToyToAffineLoweringPass, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect, func::FuncDialect, memref::MemRefDialect>();
   }

@@ -286,19 +318,6 @@ struct ToyToAffineLoweringPass
 } // namespace

 void ToyToAffineLoweringPass::runOnOperation() {
-  auto function = getOperation();
-
-  // We only lower the main function as we expect that all other functions have
-  // been inlined.
-  if (function.getName() != "main")
-    return;
-
-  // Verify that the given main has no inputs and results.
-  if (function.getNumArguments() || function.getType().getNumResults()) {
-    function.emitError("expected 'main' to have 0 inputs and 0 results");
-    return signalPassFailure();
-  }
-
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering.
   ConversionTarget target(getContext());

@@ -306,8 +325,9 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // We define the specific operations, or dialects, that are legal targets for
   // this lowering. In our case, we are lowering to a combination of the
   // `Affine`, `Arithmetic`, `Func`, and `MemRef` dialects.
-  target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
-                         func::FuncDialect, memref::MemRefDialect>();
+  target
+      .addLegalDialect<AffineDialect, BuiltinDialect, arith::ArithmeticDialect,
+                       func::FuncDialect, memref::MemRefDialect>();

   // We also define the Toy dialect as Illegal so that the conversion will fail
   // if any of these operations are *not* converted. Given that we actually want

@@ -324,7 +344,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   // Now that the conversion target has been defined, we just need to provide
   // the set of patterns that will lower the Toy operations.
   RewritePatternSet patterns(&getContext());
-  patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
+  patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
                PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
       &getContext());

@@ -60,11 +60,9 @@ public:

   for (auto &record : moduleAST) {
     if (FunctionAST *funcAST = llvm::dyn_cast<FunctionAST>(record.get())) {
-      auto func = mlirGen(*funcAST);
+      mlir::toy::FuncOp func = mlirGen(*funcAST);
       if (!func)
         return nullptr;
-
-      theModule.push_back(func);
       functionMap.insert({func.getName(), func});
     } else if (StructAST *str = llvm::dyn_cast<StructAST>(record.get())) {
       if (failed(mlirGen(*str)))

@@ -105,7 +103,7 @@ private:
       std::pair<mlir::Value, VarDeclExprAST *>>;

   /// A mapping for the functions that have been code generated to MLIR.
-  llvm::StringMap<mlir::FuncOp> functionMap;
+  llvm::StringMap<mlir::toy::FuncOp> functionMap;

   /// A mapping for named struct types to the underlying MLIR type and the
   /// original AST node.

@@ -157,7 +155,7 @@ private:

   /// Create the prototype for an MLIR function with as many arguments as the
   /// provided Toy AST prototype.
-  mlir::FuncOp mlirGen(PrototypeAST &proto) {
+  mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
     auto location = loc(proto.loc());

     // This is a generic function, the return type will be inferred later.

@@ -170,23 +168,23 @@ private:
       argTypes.push_back(type);
     }
     auto funcType = builder.getFunctionType(argTypes, llvm::None);
-    return mlir::FuncOp::create(location, proto.getName(), funcType);
+    return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
+                                             funcType);
   }

   /// Emit a new function and add it to the MLIR module.
-  mlir::FuncOp mlirGen(FunctionAST &funcAST) {
+  mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
     // Create a scope in the symbol table to hold variable declarations.
     SymbolTableScopeT varScope(symbolTable);

     // Create an MLIR function for the given prototype.
-    mlir::FuncOp function(mlirGen(*funcAST.getProto()));
+    builder.setInsertionPointToEnd(theModule.getBody());
+    mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
     if (!function)
       return nullptr;

     // Let's start the body of the function now!
     // In MLIR the entry block of the function is special: it must have the same
     // argument list as the function itself.
-    auto &entryBlock = *function.addEntryBlock();
+    mlir::Block &entryBlock = function.front();
     auto protoArgs = funcAST.getProto()->getArgs();

     // Declare all the function arguments in the symbol table.

@@ -519,7 +517,7 @@ private:
       emitError(location) << "no defined function found for '" << callee << "'";
       return nullptr;
     }
-    mlir::FuncOp calledFunc = calledFuncIt->second;
+    mlir::toy::FuncOp calledFunc = calledFuncIt->second;
     return builder.create<GenericCallOp>(
         location, calledFunc.getType().getResult(0),
         mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);

@@ -45,7 +45,7 @@ namespace {
 /// 3) If the worklist is empty, the algorithm succeeded.
 ///
 class ShapeInferencePass
-    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
+    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
 public:
   void runOnOperation() override {
     auto f = getOperation();

@@ -146,7 +146,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
     // Now that there is only one function, we can infer the shapes of each of
     // the operations.
-    mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
+    mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
     optPM.addPass(mlir::createCanonicalizerPass());
     optPM.addPass(mlir::toy::createShapeInferencePass());
     optPM.addPass(mlir::createCanonicalizerPass());

@@ -154,10 +154,11 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
   }

   if (isLoweringToAffine) {
-    mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
-
-    // Partially lower the toy dialect with a few cleanups afterwards.
-    optPM.addPass(mlir::toy::createLowerToAffinePass());
+    // Partially lower the toy dialect.
+    pm.addPass(mlir::toy::createLowerToAffinePass());
+
+    // Add a few cleanups post lowering.
+    mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
     optPM.addPass(mlir::createCanonicalizerPass());
     optPM.addPass(mlir::createCSEPass());
@@ -13,14 +13,14 @@ def main() {
   print(d);
 }

-# CHECK-LABEL: func @multiply_transpose(
-# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
+# CHECK-LABEL: toy.func @multiply_transpose(
+# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
 # CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
 # CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -4,6 +4,6 @@
 // - toy.print should not return a value.
 // - toy.print should take an argument.
 // - There should be a block terminator.
-func @main() {
+toy.func @main() {
   %0 = "toy.print"() : () -> tensor<2x3xf64>
 }

@@ -5,7 +5,7 @@ def main() {
   print(a);
 }

-# CHECK-LABEL: func @main() {
+# CHECK-LABEL: toy.func @main() {
 # CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
 # CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
 # CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

@@ -13,14 +13,14 @@ def main() {
   print(d);
 }

-# CHECK-LABEL: func @multiply_transpose(
-# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
+# CHECK-LABEL: toy.func @multiply_transpose(
+# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
 # CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
 # CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -4,6 +4,6 @@
 // - toy.print should not return a value.
 // - toy.print should take an argument.
 // - There should be a block terminator.
-func @main() {
+toy.func @main() {
   %0 = "toy.print"() : () -> tensor<2x3xf64>
 }

@@ -5,7 +5,7 @@ def main() {
   print(a);
 }

-# CHECK-LABEL: func @main() {
+# CHECK-LABEL: toy.func @main() {
 # CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
 # CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
 # CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

@@ -11,11 +11,11 @@ def main() {
   print(b);
 }

-# CHECK-LABEL: func @transpose_transpose(
-# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>) -> tensor<*xf64>
+# CHECK-LABEL: toy.func @transpose_transpose(
+# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>) -> tensor<*xf64>
 # CHECK-NEXT: toy.return [[VAL_0]] : tensor<*xf64>

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_2:%.*]] = toy.generic_call @transpose_transpose([[VAL_1]]) : (tensor<2x3xf64>) -> tensor<*xf64>
 # CHECK-NEXT: toy.print [[VAL_2]] : tensor<*xf64>

@@ -7,7 +7,7 @@ def main() {
   print(c);
 }

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
 # CHECK-SAME: dense<[
 # CHECK-SAME: [1.000000e+00], [2.000000e+00]

@@ -13,14 +13,14 @@ def main() {
   print(d);
 }

-# CHECK-LABEL: func private @multiply_transpose(
+# CHECK-LABEL: toy.func private @multiply_transpose(
 # CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
 # CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
 # CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -4,6 +4,6 @@
 // - toy.print should not return a value.
 // - toy.print should take an argument.
 // - There should be a block terminator.
-func @main() {
+toy.func @main() {
   %0 = "toy.print"() : () -> tensor<2x3xf64>
 }

@@ -5,7 +5,7 @@ def main() {
   print(a);
 }

-# CHECK-LABEL: func @main() {
+# CHECK-LABEL: toy.func @main() {
 # CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
 # CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
 # CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

@@ -2,13 +2,13 @@

 // Check the result of inlining+shape inference on an input module.

-func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
   %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
   %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
   %2 = toy.mul %0, %1 : tensor<*xf64>
   toy.return %2 : tensor<*xf64>
 }
-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
   %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -19,10 +19,10 @@ func @main() {
   toy.return
 }

-// CHECK-NOT: func private @multiply_transpose
+// CHECK-NOT: toy.func private @multiply_transpose
 // CHECK-NOT: tensor<*xf64>

-// CHECK-LABEL: func @main()
+// CHECK-LABEL: toy.func @main()
 // CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 // CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
 // CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>

@@ -11,7 +11,7 @@ def main() {
   print(b);
 }

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
 # CHECK-NEXT: toy.return

@@ -7,7 +7,7 @@ def main() {
   print(c);
 }

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
 # CHECK-SAME: dense<[
 # CHECK-SAME: [1.000000e+00], [2.000000e+00]
@@ -1,7 +1,7 @@
 // RUN: toyc-ch5 %s -emit=mlir-affine 2>&1 | FileCheck %s
 // RUN: toyc-ch5 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT

-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
   %3 = toy.mul %2, %2 : tensor<3x2xf64>

@@ -13,14 +13,14 @@ def main() {
   print(d);
 }

-# CHECK-LABEL: func private @multiply_transpose(
+# CHECK-LABEL: toy.func private @multiply_transpose(
 # CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
 # CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
 # CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -4,6 +4,6 @@
 // - toy.print should not return a value.
 // - toy.print should take an argument.
 // - There should be a block terminator.
-func @main() {
+toy.func @main() {
   %0 = "toy.print"() : () -> tensor<2x3xf64>
 }

@@ -2,13 +2,13 @@

 // Check the result of inlining+shape inference on an input module.

-func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
   %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
   %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
   %2 = toy.mul %0, %1 : tensor<*xf64>
   toy.return %2 : tensor<*xf64>
 }
-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
   %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -19,10 +19,10 @@ func @main() {
   toy.return
 }

-// CHECK-NOT: func @multiply_transpose
+// CHECK-NOT: toy.func @multiply_transpose
 // CHECK-NOT: tensor<*xf64>

-// CHECK-LABEL: func @main()
+// CHECK-LABEL: toy.func @main()
 // CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 // CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
 // CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>

@@ -11,7 +11,7 @@ def main() {
   print(b);
 }

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
 # CHECK-NEXT: toy.return

@@ -7,7 +7,7 @@ def main() {
   print(c);
 }

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
 # CHECK-SAME: dense<[
 # CHECK-SAME: [1.000000e+00], [2.000000e+00]

@@ -1,7 +1,7 @@
 // RUN: toyc-ch6 %s -emit=mlir-affine 2>&1 | FileCheck %s
 // RUN: toyc-ch6 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT

-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
   %3 = toy.mul %2, %2 : tensor<3x2xf64>

@@ -13,14 +13,14 @@ def main() {
   print(d);
 }

-# CHECK-LABEL: func private @multiply_transpose(
+# CHECK-LABEL: toy.func private @multiply_transpose(
 # CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
 # CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
 # CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -4,6 +4,6 @@
 // - toy.print should not return a value.
 // - toy.print should take an argument.
 // - There should be a block terminator.
-func @main() {
+toy.func @main() {
   %0 = "toy.print"() : () -> tensor<2x3xf64>
 }

@@ -1,6 +1,6 @@
 // RUN: toyc-ch6 %s -emit=llvm -opt

-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
   %3 = toy.mul %2, %2 : tensor<3x2xf64>

@@ -5,7 +5,7 @@ def main() {
   print(a);
 }

-# CHECK-LABEL: func @main() {
+# CHECK-LABEL: toy.func @main() {
 # CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
 # CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
 # CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

@@ -2,13 +2,13 @@

 // Check the result of inlining+shape inference on an input module.

-func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
   %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
   %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
   %2 = toy.mul %0, %1 : tensor<*xf64>
   toy.return %2 : tensor<*xf64>
 }
-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
   %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -19,10 +19,10 @@ func @main() {
   toy.return
 }

-// CHECK-NOT: func @multiply_transpose
+// CHECK-NOT: toy.func @multiply_transpose
 // CHECK-NOT: tensor<*xf64>

-// CHECK-LABEL: func @main()
+// CHECK-LABEL: toy.func @main()
 // CHECK: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 // CHECK: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
 // CHECK: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>

@@ -11,7 +11,7 @@ def main() {
   print(b);
 }

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
 # CHECK-NEXT: toy.return

@@ -7,7 +7,7 @@ def main() {
   print(c);
 }

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
 # CHECK-SAME: dense<[
 # CHECK-SAME: [1.000000e+00], [2.000000e+00]
@@ -1,7 +1,7 @@
 // RUN: toyc-ch7 %s -emit=mlir-affine 2>&1 | FileCheck %s
 // RUN: toyc-ch7 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT

-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
   %3 = toy.mul %2, %2 : tensor<3x2xf64>

@@ -13,14 +13,14 @@ def main() {
   print(d);
 }

-# CHECK-LABEL: func private @multiply_transpose(
+# CHECK-LABEL: toy.func private @multiply_transpose(
 # CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) -> tensor<*xf64>
 # CHECK: [[VAL_2:%.*]] = toy.transpose([[VAL_0]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_3:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>
 # CHECK-NEXT: [[VAL_4:%.*]] = toy.mul [[VAL_2]], [[VAL_3]] : tensor<*xf64>
 # CHECK-NEXT: toy.return [[VAL_4]] : tensor<*xf64>

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_5:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_6:%.*]] = toy.reshape([[VAL_5]] : tensor<2x3xf64>) to tensor<2x3xf64>
 # CHECK-NEXT: [[VAL_7:%.*]] = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -4,6 +4,6 @@
 // - toy.print should not return a value.
 // - toy.print should take an argument.
 // - There should be a block terminator.
-func @main() {
+toy.func @main() {
   %0 = "toy.print"() : () -> tensor<2x3xf64>
 }

@@ -1,6 +1,6 @@
 // RUN: toyc-ch7 %s -emit=llvm -opt

-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
   %3 = toy.mul %2, %2 : tensor<3x2xf64>

@@ -5,7 +5,7 @@ def main() {
   print(a);
 }

-# CHECK-LABEL: func @main() {
+# CHECK-LABEL: toy.func @main() {
 # CHECK-NEXT: %0 = toy.constant dense<5.500000e+00> : tensor<f64>
 # CHECK-NEXT: %1 = toy.reshape(%0 : tensor<f64>) to tensor<2x2xf64>
 # CHECK-NEXT: toy.print %1 : tensor<2x2xf64>

@@ -2,13 +2,13 @@

 // Check the result of inlining+shape inference on an input module.

-func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+toy.func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
   %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
   %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
   %2 = toy.mul %0, %1 : tensor<*xf64>
   toy.return %2 : tensor<*xf64>
 }
-func @main() {
+toy.func @main() {
   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
   %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
   %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>

@@ -21,7 +21,7 @@ def main() {
   print(c);
 }

-# CHECK-LABEL: func private @multiply_transpose(
+# CHECK-LABEL: toy.func private @multiply_transpose(
 # CHECK-SAME: [[VAL_0:%.*]]: !toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
 # CHECK-NEXT: [[VAL_1:%.*]] = toy.struct_access [[VAL_0]][0] : !toy.struct<tensor<*xf64>, tensor<*xf64>> -> tensor<*xf64>
 # CHECK-NEXT: [[VAL_2:%.*]] = toy.transpose([[VAL_1]] : tensor<*xf64>) to tensor<*xf64>

@@ -30,13 +30,13 @@ def main() {
 # CHECK-NEXT: [[VAL_5:%.*]] = toy.mul [[VAL_2]], [[VAL_4]] : tensor<*xf64>
 # CHECK-NEXT: toy.return [[VAL_5]] : tensor<*xf64>

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_6:%.*]] = toy.struct_constant [dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>, dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>] : !toy.struct<tensor<*xf64>, tensor<*xf64>>
 # CHECK-NEXT: [[VAL_7:%.*]] = toy.generic_call @multiply_transpose([[VAL_6]]) : (!toy.struct<tensor<*xf64>, tensor<*xf64>>) -> tensor<*xf64>
 # CHECK-NEXT: toy.print [[VAL_7]] : tensor<*xf64>
 # CHECK-NEXT: toy.return

-# OPT-LABEL: func @main()
+# OPT-LABEL: toy.func @main()
 # OPT-NEXT: [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # OPT-NEXT: [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
 # OPT-NEXT: [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>

@@ -1,6 +1,6 @@
 // RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s

-func @main() {
+toy.func @main() {
   %0 = toy.struct_constant [
     [dense<4.000000e+00> : tensor<2x2xf64>], dense<4.000000e+00> : tensor<2x2xf64>
   ] : !toy.struct<!toy.struct<tensor<*xf64>>, tensor<*xf64>>

@@ -10,6 +10,6 @@ func @main() {
   toy.return
 }

-// CHECK-LABEL: func @main
+// CHECK-LABEL: toy.func @main
 // CHECK-NEXT: %[[CST:.*]] = toy.constant dense<4.0
 // CHECK-NEXT: toy.print %[[CST]]

@@ -11,7 +11,7 @@ def main() {
   print(b);
 }

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_1:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
 # CHECK-NEXT: toy.print [[VAL_1]] : tensor<2x3xf64>
 # CHECK-NEXT: toy.return

@@ -7,7 +7,7 @@ def main() {
   print(c);
 }

-# CHECK-LABEL: func @main()
+# CHECK-LABEL: toy.func @main()
 # CHECK-NEXT: [[VAL_0:%.*]] = toy.constant
 # CHECK-SAME: dense<[
 # CHECK-SAME: [1.000000e+00], [2.000000e+00]