forked from OSchip/llvm-project
96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
# RUN: %PYTHON %s 2>&1 | FileCheck %s
|
|
|
|
import gc, sys
|
|
from mlir.ir import *
|
|
from mlir.passmanager import *
|
|
|
|
# Log everything to stderr and flush so that we have a unified stream to match
|
|
# errors/info emitted by MLIR to stderr.
|
|
def log(*args):
    """Print ``args`` to stderr and flush the stream right away.

    The test writes its own output to stderr, and flushes after every
    message, so that it forms a single ordered stream together with the
    errors/info MLIR emits to stderr.

    Args:
        *args: Values forwarded unchanged to ``print``.
    """
    # ``flush=True`` makes print flush the target stream itself, replacing the
    # separate ``sys.stderr.flush()`` call.
    print(*args, file=sys.stderr, flush=True)
|
|
|
|
def run(f):
    """Run one test callable and verify that no Context is left alive.

    Logs a ``TEST: <name>`` banner (preceded by a blank line) via ``log``,
    invokes ``f`` with no arguments, forces a garbage collection and then
    asserts that ``Context._get_live_count()`` reports zero.

    Args:
        f: Zero-argument callable; its ``__name__`` is used in the banner.
    """
    test_name = f.__name__
    log("\nTEST:", test_name)
    f()
    # Collect garbage first so objects the test dropped are actually released
    # before the live-context count is inspected.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
|
|
|
|
# Verify capsule interop.
|
|
# CHECK-LABEL: TEST: testCapsule
|
|
def testCapsule():
    """Round-trip a PassManager through its C API capsule.

    Takes the ``_CAPIPtr`` capsule of a fresh PassManager, checks the capsule
    name shown in its repr, releases the original wrapper through
    ``_testing_release`` and rebuilds a PassManager from the capsule with
    ``_CAPICreate``.
    """
    with Context():
        original = PassManager()
        capsule = original._CAPIPtr
        capsule_repr = repr(capsule)
        assert '"mlir.passmanager.PassManager._CAPIPtr"' in capsule_repr
        # Release the original wrapper before re-creating one from the capsule.
        original._testing_release()
        recreated = PassManager._CAPICreate(capsule)
        # Getting here at all also shows the round-trip does not crash.
        assert recreated is not None


run(testCapsule)
|
|
|
|
|
|
# Verify successful round-trip.
|
|
# CHECK-LABEL: TEST: testParseSuccess
|
|
def testParseSuccess():
    """Parse a pipeline before and after its pass has been registered.

    The first ``PassManager.parse`` is expected to raise ``ValueError``, which
    is logged. After ``mlir.transforms`` is imported the same pipeline is
    parsed again and the resulting PassManager is logged as the round-trip.
    """
    pipeline = "builtin.module(builtin.func(print-op-stats))"
    with Context():
        # The first parse is expected to fail because the pass isn't registered
        # until mlir.transforms is imported.
        try:
            manager = PassManager.parse(pipeline)
            # TODO: this error should be propagated to Python but the C API does not help right now.
            # CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline
        except ValueError as err:
            # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(builtin.func(print-op-stats))'.
            log("ValueError exception:", err)
        else:
            log("Exception not produced")

        # Importing mlir.transforms registers the pass, so the round-trip
        # should be possible from here on.
        import mlir.transforms

        manager = PassManager.parse(pipeline)
        # CHECK: Roundtrip: builtin.module(builtin.func(print-op-stats))
        log("Roundtrip: ", manager)


run(testParseSuccess)
|
|
|
|
# Verify failure on unregistered pass.
|
|
# CHECK-LABEL: TEST: testParseFail
|
|
def testParseFail():
    """Parse a pipeline naming an unknown pass and log the outcome.

    The expected path is that ``PassManager.parse`` raises ``ValueError``,
    which is logged; if nothing is raised, that is logged instead.
    """
    pipeline = "unknown-pass"
    with Context():
        try:
            manager = PassManager.parse(pipeline)
        except ValueError as err:
            # CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'.
            log("ValueError exception:", err)
        else:
            log("Exception not produced")


run(testParseFail)
|
|
|
|
|
|
# Verify failure on incorrect level of nesting.
|
|
# CHECK-LABEL: TEST: testInvalidNesting
|
|
def testInvalidNesting():
    """Parse a pipeline that nests a pass at the wrong level and log the outcome.

    ``normalize-memrefs`` is requested directly under ``builtin.func``; the
    expected path is a ``ValueError`` from ``PassManager.parse``, which is
    logged. If nothing is raised, that is logged instead.
    """
    pipeline = "builtin.func(normalize-memrefs)"
    with Context():
        try:
            manager = PassManager.parse(pipeline)
        except ValueError as err:
            # CHECK: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'builtin.func', did you intend to nest?
            # CHECK: ValueError exception: invalid pass pipeline 'builtin.func(normalize-memrefs)'.
            log("ValueError exception:", err)
        else:
            log("Exception not produced")


run(testInvalidNesting)
|
|
|
|
|
|
# Verify that a pass manager can execute on IR
|
|
# CHECK-LABEL: TEST: testRun
|
|
def testRunPipeline():
    """Run a parsed ``print-op-stats`` pipeline over a small module.

    Builds a PassManager from the pipeline string, parses a module holding a
    single function that only returns, and runs the pass manager on it. The
    FileCheck directives after this function match the statistics output.
    """
    module_asm = "func @successfulParse() { return }"
    with Context():
        # NOTE(review): presumably relies on print-op-stats already being
        # registered (testParseSuccess imports mlir.transforms earlier in this
        # file) - confirm if the tests are ever reordered.
        manager = PassManager.parse("print-op-stats")
        parsed_module = Module.parse(module_asm)
        manager.run(parsed_module)


# CHECK: Operations encountered:
# CHECK: builtin.func , 1
# CHECK: builtin.module , 1
# CHECK: std.return , 1
run(testRunPipeline)
|