[mlir] Provide minimal Python bindings for the math dialect

Reviewed By: ulysseB

Differential Revision: https://reviews.llvm.org/D104045
This commit is contained in:
Alex Zinenko 2021-06-10 19:00:34 +02:00
parent c1bb75febe
commit ad381e39a5
4 changed files with 51 additions and 0 deletions

View File

@ -25,6 +25,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps
DEPENDS LinalgOdsGen)
add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonLinalgOps)
add_mlir_dialect_python_bindings(MLIRBindingsPythonMathOps
TD_FILE MathOps.td
DIALECT_NAME math)
add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonMathOps)
add_mlir_dialect_python_bindings(MLIRBindingsPythonMemRefOps
TD_FILE MemRefOps.td
DIALECT_NAME memref)

View File

@ -0,0 +1,15 @@
//===-- MathOps.td - Entry point for MathOps bindings ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_MATH_OPS
#define PYTHON_BINDINGS_MATH_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/Math/IR/MathOps.td"
#endif

View File

@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._math_ops_gen import *

View File

@ -0,0 +1,26 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.math as mlir_math
def run(f):
print("\nTEST:", f.__name__)
f()
# CHECK-LABEL: TEST: testMathOps
@run
def testMathOps():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(F32Type.get())
def emit_sqrt(arg):
return mlir_math.SqrtOp(F32Type.get(), arg)
# CHECK-LABEL: func @emit_sqrt(
# CHECK-SAME: %[[ARG:.*]]: f32) {
# CHECK: math.sqrt %[[ARG]] : f32
# CHECK: return
# CHECK: }
print(module)