[mlir][Python] Re-export cext sparse_tensor module to the public namespace.

* This was left out of the previous commit accidentally.

Differential Revision: https://reviews.llvm.org/D102183
This commit is contained in:
Stella Laurenzo 2021-05-10 17:42:24 +00:00
parent 08cf2776ac
commit f38633d1bb
3 changed files with 12 additions and 3 deletions

View File

@ -45,7 +45,10 @@ def _reexport_cext(cext_module_name, target_module_name):
"""
import sys
target_module = sys.modules[target_module_name]
source_module = getattr(_cext, cext_module_name)
submodule_names = cext_module_name.split(".")
source_module = _cext
for submodule_name in submodule_names:
source_module = getattr(source_module, submodule_name)
for attr_name in dir(source_module):
if not attr_name.startswith("__"):
setattr(target_module, attr_name, getattr(source_module, attr_name))

View File

@ -0,0 +1,7 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._cext_loader import _reexport_cext
_reexport_cext("dialects.sparse_tensor", __name__)
del _reexport_cext

View File

@ -1,8 +1,7 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
# TODO: Import this into the user-package vs the cext.
from _mlir.dialects import sparse_tensor as st
from mlir.dialects import sparse_tensor as st
def run(f):
print("\nTEST:", f.__name__)