From 4cd1b66dffb06695b4eaf725df8c402347e39bf0 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 20 May 2021 17:51:53 +0900 Subject: [PATCH] [mlir] Add Python bindings for vector dialect Also add a minimal test case for vector.print. Differential Revision: https://reviews.llvm.org/D102826 --- mlir/python/mlir/dialects/CMakeLists.txt | 5 +++++ mlir/python/mlir/dialects/VectorOps.td | 15 ++++++++++++++ mlir/python/mlir/dialects/vector.py | 5 +++++ mlir/test/python/dialects/vector.py | 26 ++++++++++++++++++++++++ 4 files changed, 51 insertions(+) create mode 100644 mlir/python/mlir/dialects/VectorOps.td create mode 100644 mlir/python/mlir/dialects/vector.py create mode 100644 mlir/test/python/dialects/vector.py diff --git a/mlir/python/mlir/dialects/CMakeLists.txt b/mlir/python/mlir/dialects/CMakeLists.txt index 31a4ee55b9d3..cad3bb7100e2 100644 --- a/mlir/python/mlir/dialects/CMakeLists.txt +++ b/mlir/python/mlir/dialects/CMakeLists.txt @@ -45,6 +45,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonTensorOps DIALECT_NAME tensor) add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTensorOps) +add_mlir_dialect_python_bindings(MLIRBindingsPythonVectorOps + TD_FILE VectorOps.td + DIALECT_NAME vector) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonVectorOps) + ################################################################################ # Installation. ################################################################################ diff --git a/mlir/python/mlir/dialects/VectorOps.td b/mlir/python/mlir/dialects/VectorOps.td new file mode 100644 index 000000000000..b06668bdf4c5 --- /dev/null +++ b/mlir/python/mlir/dialects/VectorOps.td @@ -0,0 +1,15 @@ +//===-- VectorOps.td - Entry point for VectorOps bind ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOR_OPS +#define PYTHON_BINDINGS_VECTOR_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Vector/VectorOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py new file mode 100644 index 000000000000..610c0b204c6b --- /dev/null +++ b/mlir/python/mlir/dialects/vector.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._vector_ops_gen import * diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py new file mode 100644 index 000000000000..4d7052859e7d --- /dev/null +++ b/mlir/test/python/dialects/vector.py @@ -0,0 +1,26 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +import mlir.dialects.builtin as builtin +import mlir.dialects.vector as vector + +def run(f): + print("\nTEST:", f.__name__) + f() + +# CHECK-LABEL: TEST: testPrintOp +@run +def testPrintOp(): + with Context() as ctx, Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + @builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get())) + def print_vector(arg): + return vector.PrintOp(arg) + + # CHECK-LABEL: func @print_vector( + # CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) { + # CHECK: vector.print %[[ARG]] : vector<12x5xf32> + # CHECK: return + # CHECK: } + print(module)