[mlir] Add Python bindings for vector dialect

Also add a minimal test case for vector.print.

Differential Revision: https://reviews.llvm.org/D102826
This commit is contained in:
Matthias Springer 2021-05-20 17:51:53 +09:00
parent 412a3381f7
commit 4cd1b66dff
4 changed files with 51 additions and 0 deletions

View File

@ -45,6 +45,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonTensorOps
DIALECT_NAME tensor)
add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTensorOps)
add_mlir_dialect_python_bindings(MLIRBindingsPythonVectorOps
TD_FILE VectorOps.td
DIALECT_NAME vector)
add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonVectorOps)
################################################################################
# Installation.
################################################################################

View File

@ -0,0 +1,15 @@
//===-- VectorOps.td - Entry point for VectorOps bind ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_VECTOR_OPS
#define PYTHON_BINDINGS_VECTOR_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/Vector/VectorOps.td"
#endif

View File

@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._vector_ops_gen import *

View File

@ -0,0 +1,26 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.vector as vector
def run(f):
print("\nTEST:", f.__name__)
f()
# CHECK-LABEL: TEST: testPrintOp
@run
def testPrintOp():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
def print_vector(arg):
return vector.PrintOp(arg)
# CHECK-LABEL: func @print_vector(
# CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) {
# CHECK: vector.print %[[ARG]] : vector<12x5xf32>
# CHECK: return
# CHECK: }
print(module)