forked from OSchip/llvm-project
Move edsc python tests to Filecheck
-- PiperOrigin-RevId: 247479507
This commit is contained in:
parent
33449c3e6c
commit
b4c06416df
|
@ -12,17 +12,21 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# RUN: $(dirname %s)/test_edsc %s | FileCheck %s
|
||||
"""Python2 and 3 test for the MLIR EDSC Python bindings"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import google_mlir.bindings.python.pybind as E
|
||||
import inspect
|
||||
|
||||
class EdscTest(unittest.TestCase):
|
||||
# Prints `str` prefixed by the current test function name so we can use it in
|
||||
# Filecheck label directives.
|
||||
# This is achieved by inspecting the stack and getting the parent name.
|
||||
def printWithCurrentFunctionName(str):
|
||||
print(inspect.stack()[1][3])
|
||||
print(str)
|
||||
|
||||
class EdscTest:
|
||||
|
||||
def setUp(self):
|
||||
self.module = E.MLIRModule()
|
||||
|
@ -32,34 +36,36 @@ class EdscTest(unittest.TestCase):
|
|||
self.indexType = self.module.make_index_type()
|
||||
|
||||
def testFunctionContext(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []):
|
||||
pass
|
||||
self.assertIsNotNone(self.module.get_function("foo"))
|
||||
printWithCurrentFunctionName(self.module.get_function("foo"))
|
||||
# CHECK-LABEL: testFunctionContext
|
||||
# CHECK: func @foo() {
|
||||
|
||||
def testMultipleFunctions(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []):
|
||||
pass
|
||||
with self.module.function_context("foo", [], []):
|
||||
E.constant_index(0)
|
||||
code = str(self.module)
|
||||
self.assertIn("func @foo()", code)
|
||||
self.assertIn(" %c0 = constant 0 : index", code)
|
||||
|
||||
with self.module.function_context("bar", [], []):
|
||||
E.constant_index(42)
|
||||
code = str(self.module)
|
||||
barPos = code.find("func @bar()")
|
||||
c42Pos = code.find("%c42 = constant 42 : index")
|
||||
self.assertNotEqual(barPos, -1)
|
||||
self.assertNotEqual(c42Pos, -1)
|
||||
self.assertGreater(c42Pos, barPos)
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testMultipleFunctions
|
||||
# CHECK: func @foo()
|
||||
# CHECK: func @foo_0()
|
||||
# CHECK: %c0 = constant 0 : index
|
||||
|
||||
def testFunctionArgs(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [self.f32Type, self.f32Type],
|
||||
[self.indexType]) as fun:
|
||||
pass
|
||||
code = str(fun)
|
||||
self.assertIn("func @foo(%arg0: f32, %arg1: f32) -> index", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testFunctionArgs
|
||||
# CHECK: func @foo(%arg0: f32, %arg1: f32) -> index
|
||||
|
||||
def testLoopContext(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []) as fun:
|
||||
lhs = E.constant_index(0)
|
||||
rhs = E.constant_index(42)
|
||||
|
@ -67,61 +73,49 @@ class EdscTest(unittest.TestCase):
|
|||
lhs + rhs + i
|
||||
with E.LoopContext(rhs, rhs + rhs, 2) as j:
|
||||
x = i + j
|
||||
code = str(fun)
|
||||
# TODO(zinenko,ntv): use FileCheck for these tests
|
||||
self.assertIn(' "affine.for"() ( {\n', code)
|
||||
self.assertIn(
|
||||
"{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}",
|
||||
code)
|
||||
self.assertIn(" ^bb1(%i0: index):", code)
|
||||
self.assertIn(' "affine.for"(%c42, %2) ( {\n', code)
|
||||
self.assertIn(
|
||||
"{lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> ()",
|
||||
code)
|
||||
self.assertIn(" ^bb2(%i1: index):", code)
|
||||
self.assertIn(
|
||||
' %3 = "affine.apply"(%i0, %i1) {map: (d0, d1) -> (d0 + d1)} : (index, index) -> index',
|
||||
code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testLoopContext
|
||||
# CHECK: "affine.for"() (
|
||||
# CHECK: ^bb1(%i0: index):
|
||||
# CHECK: "affine.for"(%c42, %2) (
|
||||
# CHECK: ^bb2(%i1: index):
|
||||
# CHECK: "affine.apply"(%i0, %i1) {map: (d0, d1) -> (d0 + d1)} : (index, index) -> index
|
||||
# CHECK: {lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> ()
|
||||
# CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}
|
||||
|
||||
def testLoopNestContext(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []) as fun:
|
||||
lbs = [E.constant_index(i) for i in range(4)]
|
||||
ubs = [E.constant_index(10 * i + 5) for i in range(4)]
|
||||
with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l):
|
||||
i + j + k + l
|
||||
|
||||
code = str(fun)
|
||||
self.assertIn(' "affine.for"() ( {\n', code)
|
||||
self.assertIn(" ^bb1(%i0: index):", code)
|
||||
self.assertIn(' "affine.for"() ( {\n', code)
|
||||
self.assertIn(" ^bb2(%i1: index):", code)
|
||||
self.assertIn(' "affine.for"() ( {\n', code)
|
||||
self.assertIn(" ^bb3(%i2: index):", code)
|
||||
self.assertIn(' "affine.for"() ( {\n', code)
|
||||
self.assertIn(" ^bb4(%i3: index):", code)
|
||||
self.assertIn(
|
||||
' %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index',
|
||||
code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testLoopNestContext
|
||||
# CHECK: "affine.for"() (
|
||||
# CHECK: ^bb1(%i0: index):
|
||||
# CHECK: "affine.for"() (
|
||||
# CHECK: ^bb2(%i1: index):
|
||||
# CHECK: "affine.for"() (
|
||||
# CHECK: ^bb3(%i2: index):
|
||||
# CHECK: "affine.for"() (
|
||||
# CHECK: ^bb4(%i3: index):
|
||||
# CHECK: %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index
|
||||
|
||||
def testBlockContext(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []) as fun:
|
||||
cst = E.constant_index(42)
|
||||
with E.BlockContext():
|
||||
cst + cst
|
||||
code = str(fun)
|
||||
# Find positions of instructions and make sure they are in the block we
|
||||
# put them by comparing those positions.
|
||||
# TODO(zinenko,ntv): this (and tests below) should use FileCheck instead.
|
||||
c42pos = code.find("%c42 = constant 42 : index")
|
||||
bb1pos = code.find("^bb1:")
|
||||
c84pos = code.find('%0 = "affine.apply"() {map: () -> (84)} : () -> index')
|
||||
self.assertNotEqual(c42pos, -1)
|
||||
self.assertNotEqual(bb1pos, -1)
|
||||
self.assertNotEqual(c84pos, -1)
|
||||
self.assertGreater(bb1pos, c42pos)
|
||||
self.assertLess(bb1pos, c84pos)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBlockContext
|
||||
# CHECK: %c42 = constant 42 : index
|
||||
# CHECK: ^bb1:
|
||||
# CHECK: %0 = "affine.apply"() {map: () -> (84)} : () -> index
|
||||
|
||||
def testBlockContextAppend(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []) as fun:
|
||||
E.constant_index(41)
|
||||
with E.BlockContext() as b:
|
||||
|
@ -130,25 +124,16 @@ class EdscTest(unittest.TestCase):
|
|||
E.constant_index(42)
|
||||
with E.BlockContext(E.appendTo(blk)):
|
||||
E.constant_index(1)
|
||||
code = str(fun)
|
||||
# Find positions of instructions and make sure they are in the block we put
|
||||
# them by comparing those positions.
|
||||
c41pos = code.find("%c41 = constant 41 : index")
|
||||
c42pos = code.find("%c42 = constant 42 : index")
|
||||
bb1pos = code.find("^bb1:")
|
||||
c0pos = code.find("%c0 = constant 0 : index")
|
||||
c1pos = code.find("%c1 = constant 1 : index")
|
||||
self.assertNotEqual(c41pos, -1)
|
||||
self.assertNotEqual(c42pos, -1)
|
||||
self.assertNotEqual(bb1pos, -1)
|
||||
self.assertNotEqual(c0pos, -1)
|
||||
self.assertNotEqual(c1pos, -1)
|
||||
self.assertGreater(bb1pos, c41pos)
|
||||
self.assertGreater(bb1pos, c42pos)
|
||||
self.assertLess(bb1pos, c0pos)
|
||||
self.assertLess(bb1pos, c1pos)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBlockContextAppend
|
||||
# CHECK: %c41 = constant 41 : index
|
||||
# CHECK: %c42 = constant 42 : index
|
||||
# CHECK: ^bb1:
|
||||
# CHECK: %c0 = constant 0 : index
|
||||
# CHECK: %c1 = constant 1 : index
|
||||
|
||||
def testBlockContextStandalone(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []) as fun:
|
||||
blk1 = E.BlockContext()
|
||||
blk2 = E.BlockContext()
|
||||
|
@ -161,81 +146,72 @@ class EdscTest(unittest.TestCase):
|
|||
with blk1:
|
||||
E.constant_index(1)
|
||||
E.constant_index(42)
|
||||
code = str(fun)
|
||||
# Find positions of instructions and make sure they are in the block we put
|
||||
# them by comparing those positions.
|
||||
c41pos = code.find(" %c41 = constant 41 : index")
|
||||
c42pos = code.find(" %c42 = constant 42 : index")
|
||||
bb1pos = code.find("^bb1:")
|
||||
c0pos = code.find(" %c0 = constant 0 : index")
|
||||
c1pos = code.find(" %c1 = constant 1 : index")
|
||||
bb2pos = code.find("^bb2:")
|
||||
c56pos = code.find(" %c56 = constant 56 : index")
|
||||
c57pos = code.find(" %c57 = constant 57 : index")
|
||||
self.assertNotEqual(c41pos, -1)
|
||||
self.assertNotEqual(c42pos, -1)
|
||||
self.assertNotEqual(bb1pos, -1)
|
||||
self.assertNotEqual(c0pos, -1)
|
||||
self.assertNotEqual(c1pos, -1)
|
||||
self.assertNotEqual(bb2pos, -1)
|
||||
self.assertNotEqual(c56pos, -1)
|
||||
self.assertNotEqual(c57pos, -1)
|
||||
self.assertGreater(bb1pos, c41pos)
|
||||
self.assertGreater(bb1pos, c42pos)
|
||||
self.assertLess(bb1pos, c0pos)
|
||||
self.assertLess(bb1pos, c1pos)
|
||||
self.assertGreater(bb2pos, c0pos)
|
||||
self.assertGreater(bb2pos, c1pos)
|
||||
self.assertGreater(bb2pos, bb1pos)
|
||||
self.assertLess(bb2pos, c56pos)
|
||||
self.assertLess(bb2pos, c57pos)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBlockContextStandalone
|
||||
# CHECK: %c41 = constant 41 : index
|
||||
# CHECK: %c42 = constant 42 : index
|
||||
# CHECK: ^bb1:
|
||||
# CHECK: %c0 = constant 0 : index
|
||||
# CHECK: %c1 = constant 1 : index
|
||||
# CHECK: ^bb2:
|
||||
# CHECK: %c56 = constant 56 : index
|
||||
# CHECK: %c57 = constant 57 : index
|
||||
|
||||
def testBlockArguments(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []) as fun:
|
||||
E.constant_index(42)
|
||||
with E.BlockContext([self.f32Type, self.f32Type]) as b:
|
||||
b.arg(0) + b.arg(1)
|
||||
code = str(fun)
|
||||
self.assertIn("%c42 = constant 42 : index", code)
|
||||
self.assertIn("^bb1(%0: f32, %1: f32):", code)
|
||||
self.assertIn(" %2 = addf %0, %1 : f32", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBlockArguments
|
||||
# CHECK: %c42 = constant 42 : index
|
||||
# CHECK: ^bb1(%0: f32, %1: f32):
|
||||
# CHECK: %2 = addf %0, %1 : f32
|
||||
|
||||
def testBr(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []) as fun:
|
||||
with E.BlockContext() as b:
|
||||
blk = b
|
||||
E.ret()
|
||||
E.br(blk)
|
||||
code = str(fun)
|
||||
self.assertIn(" br ^bb1", code)
|
||||
self.assertIn("^bb1:", code)
|
||||
self.assertIn(" return", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBr
|
||||
# CHECK: br ^bb1
|
||||
# CHECK: ^bb1:
|
||||
# CHECK: return
|
||||
|
||||
def testBrDeclaration(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []) as fun:
|
||||
blk = E.BlockContext()
|
||||
E.br(blk.handle())
|
||||
with blk:
|
||||
E.ret()
|
||||
code = str(fun)
|
||||
self.assertIn(" br ^bb1", code)
|
||||
self.assertIn("^bb1:", code)
|
||||
self.assertIn(" return", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBrDeclaration
|
||||
# CHECK: br ^bb1
|
||||
# CHECK: ^bb1:
|
||||
# CHECK: return
|
||||
|
||||
def testBrArgs(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [], []) as fun:
|
||||
# Create an infinite loop.
|
||||
with E.BlockContext([self.indexType, self.indexType]) as b:
|
||||
E.br(b, [b.arg(1), b.arg(0)])
|
||||
E.br(b, [E.constant_index(0), E.constant_index(1)])
|
||||
code = str(fun)
|
||||
self.assertIn(" %c0 = constant 0 : index", code)
|
||||
self.assertIn(" %c1 = constant 1 : index", code)
|
||||
self.assertIn(" br ^bb1(%c0, %c1 : index, index)", code)
|
||||
self.assertIn("^bb1(%0: index, %1: index):", code)
|
||||
self.assertIn(" br ^bb1(%1, %0 : index, index)", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBrArgs
|
||||
# CHECK: %c0 = constant 0 : index
|
||||
# CHECK: %c1 = constant 1 : index
|
||||
# CHECK: br ^bb1(%c0, %c1 : index, index)
|
||||
# CHECK: ^bb1(%0: index, %1: index):
|
||||
# CHECK: br ^bb1(%1, %0 : index, index)
|
||||
|
||||
def testCondBr(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [self.boolType], []) as fun:
|
||||
with E.BlockContext() as blk1:
|
||||
E.ret([])
|
||||
|
@ -243,86 +219,95 @@ class EdscTest(unittest.TestCase):
|
|||
E.ret([])
|
||||
cst = E.constant_index(0)
|
||||
E.cond_br(fun.arg(0), blk1, [], blk2, [cst])
|
||||
|
||||
code = str(fun)
|
||||
self.assertIn("cond_br %arg0, ^bb1, ^bb2(%c0 : index)", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testCondBr
|
||||
# CHECK: cond_br %arg0, ^bb1, ^bb2(%c0 : index)
|
||||
|
||||
def testRet(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [],
|
||||
[self.indexType, self.indexType]) as fun:
|
||||
c42 = E.constant_index(42)
|
||||
c0 = E.constant_index(0)
|
||||
E.ret([c42, c0])
|
||||
code = str(fun)
|
||||
self.assertIn(" %c42 = constant 42 : index", code)
|
||||
self.assertIn(" %c0 = constant 0 : index", code)
|
||||
self.assertIn(" return %c42, %c0 : index, index", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testRet
|
||||
# CHECK: %c42 = constant 42 : index
|
||||
# CHECK: %c0 = constant 0 : index
|
||||
# CHECK: return %c42, %c0 : index, index
|
||||
|
||||
def testSelectOp(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("foo", [self.boolType],
|
||||
[self.i32Type]) as fun:
|
||||
a = E.constant_int(42, 32)
|
||||
b = E.constant_int(0, 32)
|
||||
E.ret([E.select(fun.arg(0), a, b)])
|
||||
|
||||
code = str(fun)
|
||||
self.assertIn("%0 = select %arg0, %c42_i32, %c0_i32 : i32", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testSelectOp
|
||||
# CHECK: %0 = select %arg0, %c42_i32, %c0_i32 : i32
|
||||
|
||||
def testCallOp(self):
|
||||
self.setUp()
|
||||
callee = self.module.declare_function("sqrtf", [self.f32Type],
|
||||
[self.f32Type])
|
||||
with self.module.function_context("call", [self.f32Type], []) as fun:
|
||||
funCst = E.constant_function(callee)
|
||||
funCst([fun.arg(0)]) + E.constant_float(42., self.f32Type)
|
||||
|
||||
code = str(self.module)
|
||||
self.assertIn("func @sqrtf(f32) -> f32", code)
|
||||
self.assertIn("%f = constant @sqrtf : (f32) -> f32", code)
|
||||
self.assertIn("%0 = call_indirect %f(%arg0) : (f32) -> f32", code)
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testCallOp
|
||||
# CHECK: func @sqrtf(f32) -> f32
|
||||
# CHECK: %f = constant @sqrtf : (f32) -> f32
|
||||
# CHECK: %0 = call_indirect %f(%arg0) : (f32) -> f32
|
||||
|
||||
def testBooleanOps(self):
|
||||
self.setUp()
|
||||
with self.module.function_context(
|
||||
"booleans", [self.boolType for _ in range(4)], []) as fun:
|
||||
i, j, k, l = (fun.arg(x) for x in range(4))
|
||||
stmt1 = (i < j) & (j >= k)
|
||||
stmt2 = ~(stmt1 | (k == l))
|
||||
|
||||
code = str(fun)
|
||||
self.assertIn('%0 = cmpi "slt", %arg0, %arg1 : i1', code)
|
||||
self.assertIn('%1 = cmpi "sge", %arg1, %arg2 : i1', code)
|
||||
self.assertIn("%2 = muli %0, %1 : i1", code)
|
||||
self.assertIn('%3 = cmpi "eq", %arg2, %arg3 : i1', code)
|
||||
self.assertIn("%true = constant 1 : i1", code)
|
||||
self.assertIn("%4 = subi %true, %2 : i1", code)
|
||||
self.assertIn("%true_0 = constant 1 : i1", code)
|
||||
self.assertIn("%5 = subi %true_0, %3 : i1", code)
|
||||
self.assertIn("%6 = muli %4, %5 : i1", code)
|
||||
self.assertIn("%true_1 = constant 1 : i1", code)
|
||||
self.assertIn("%7 = subi %true_1, %6 : i1", code)
|
||||
self.assertIn("%true_2 = constant 1 : i1", code)
|
||||
self.assertIn("%8 = subi %true_2, %7 : i1", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBooleanOps
|
||||
# CHECK: %0 = cmpi "slt", %arg0, %arg1 : i1
|
||||
# CHECK: %1 = cmpi "sge", %arg1, %arg2 : i1
|
||||
# CHECK: %2 = muli %0, %1 : i1
|
||||
# CHECK: %3 = cmpi "eq", %arg2, %arg3 : i1
|
||||
# CHECK: %true = constant 1 : i1
|
||||
# CHECK: %4 = subi %true, %2 : i1
|
||||
# CHECK: %true_0 = constant 1 : i1
|
||||
# CHECK: %5 = subi %true_0, %3 : i1
|
||||
# CHECK: %6 = muli %4, %5 : i1
|
||||
# CHECK: %true_1 = constant 1 : i1
|
||||
# CHECK: %7 = subi %true_1, %6 : i1
|
||||
# CHECK: %true_2 = constant 1 : i1
|
||||
# CHECK: %8 = subi %true_2, %7 : i1
|
||||
|
||||
def testDivisions(self):
|
||||
self.setUp()
|
||||
with self.module.function_context(
|
||||
"division", [self.indexType, self.i32Type, self.i32Type], []) as fun:
|
||||
# indices only support floor division
|
||||
fun.arg(0) // E.constant_index(42)
|
||||
# regular values only support regular division
|
||||
fun.arg(1) / fun.arg(2)
|
||||
|
||||
code = str(self.module)
|
||||
self.assertIn("floordiv 42", code)
|
||||
self.assertIn("divis %arg1, %arg2 : i32", code)
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testDivisions
|
||||
# CHECK: floordiv 42
|
||||
# CHECK: divis %arg1, %arg2 : i32
|
||||
|
||||
def testCustom(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("custom", [self.indexType, self.f32Type],
|
||||
[]) as fun:
|
||||
E.op("foo", [fun.arg(0)], [self.f32Type]) + fun.arg(1)
|
||||
code = str(fun)
|
||||
self.assertIn('%0 = "foo"(%arg0) : (index) -> f32', code)
|
||||
self.assertIn("%1 = addf %0, %arg1 : f32", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testCustom
|
||||
# CHECK: %0 = "foo"(%arg0) : (index) -> f32
|
||||
# CHECK: %1 = addf %0, %arg1 : f32
|
||||
|
||||
def testConstants(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("constants", [], []) as fun:
|
||||
E.constant_float(1.23, self.module.make_scalar_type("bf16"))
|
||||
E.constant_float(1.23, self.module.make_scalar_type("f16"))
|
||||
|
@ -335,20 +320,21 @@ class EdscTest(unittest.TestCase):
|
|||
E.constant_int(123, 64)
|
||||
E.constant_index(123)
|
||||
E.constant_function(fun)
|
||||
|
||||
code = str(fun)
|
||||
self.assertIn("constant 1.230000e+00 : bf16", code)
|
||||
self.assertIn("constant 1.230470e+00 : f16", code)
|
||||
self.assertIn("constant 1.230000e+00 : f32", code)
|
||||
self.assertIn("constant 1.230000e+00 : f64", code)
|
||||
self.assertIn("constant 1 : i1", code)
|
||||
self.assertIn("constant 123 : i8", code)
|
||||
self.assertIn("constant 123 : i16", code)
|
||||
self.assertIn("constant 123 : i32", code)
|
||||
self.assertIn("constant 123 : index", code)
|
||||
self.assertIn("constant @constants : () -> ()", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testConstants
|
||||
# CHECK: constant 1.230000e+00 : bf16
|
||||
# CHECK: constant 1.230470e+00 : f16
|
||||
# CHECK: constant 1.230000e+00 : f32
|
||||
# CHECK: constant 1.230000e+00 : f64
|
||||
# CHECK: constant 1 : i1
|
||||
# CHECK: constant 123 : i8
|
||||
# CHECK: constant 123 : i16
|
||||
# CHECK: constant 123 : i32
|
||||
# CHECK: constant 123 : index
|
||||
# CHECK: constant @constants : () -> ()
|
||||
|
||||
def testIndexedValue(self):
|
||||
self.setUp()
|
||||
memrefType = self.module.make_memref_type(self.f32Type, [10, 42])
|
||||
with self.module.function_context("indexed", [memrefType],
|
||||
[memrefType]) as fun:
|
||||
|
@ -359,21 +345,18 @@ class EdscTest(unittest.TestCase):
|
|||
[E.constant_index(10), E.constant_index(42)], [1, 1]) as (i, j):
|
||||
A.store([i, j], A.load([i, j]) + cst)
|
||||
E.ret([fun.arg(0)])
|
||||
|
||||
code = str(fun)
|
||||
self.assertIn('"affine.for"()', code)
|
||||
self.assertIn(
|
||||
"{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (10)}",
|
||||
code)
|
||||
self.assertIn('"affine.for"()', code)
|
||||
self.assertIn(
|
||||
"{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}",
|
||||
code)
|
||||
self.assertIn("%0 = load %arg0[%i0, %i1] : memref<10x42xf32>", code)
|
||||
self.assertIn("%1 = addf %0, %cst : f32", code)
|
||||
self.assertIn("store %1, %arg0[%i0, %i1] : memref<10x42xf32>", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testIndexedValue
|
||||
# CHECK: "affine.for"()
|
||||
# CHECK: "affine.for"()
|
||||
# CHECK: %0 = load %arg0[%i0, %i1] : memref<10x42xf32>
|
||||
# CHECK: %1 = addf %0, %cst : f32
|
||||
# CHECK: store %1, %arg0[%i0, %i1] : memref<10x42xf32>
|
||||
# CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}
|
||||
# CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (10)}
|
||||
|
||||
def testMatrixMultiply(self):
|
||||
self.setUp()
|
||||
memrefType = self.module.make_memref_type(self.f32Type, [32, 32])
|
||||
with self.module.function_context(
|
||||
"matmul", [memrefType, memrefType, memrefType], []) as fun:
|
||||
|
@ -386,65 +369,70 @@ class EdscTest(unittest.TestCase):
|
|||
k):
|
||||
C.store([i, j], A.load([i, k]) * B.load([k, j]))
|
||||
E.ret([])
|
||||
|
||||
code = str(fun)
|
||||
self.assertIn('"affine.for"()', code)
|
||||
self.assertIn(
|
||||
"{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()",
|
||||
code)
|
||||
self.assertIn("%0 = load %arg0[%i0, %i2] : memref<32x32xf32>", code)
|
||||
self.assertIn("%1 = load %arg1[%i2, %i1] : memref<32x32xf32>", code)
|
||||
self.assertIn("%2 = mulf %0, %1 : f32", code)
|
||||
self.assertIn("store %2, %arg2[%i0, %i1] : memref<32x32xf32>", code)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testMatrixMultiply
|
||||
# CHECK: "affine.for"()
|
||||
# CHECK: "affine.for"()
|
||||
# CHECK: "affine.for"()
|
||||
# CHECK-DAG: %0 = load %arg0[%i0, %i2] : memref<32x32xf32>
|
||||
# CHECK-DAG: %1 = load %arg1[%i2, %i1] : memref<32x32xf32>
|
||||
# CHECK: %2 = mulf %0, %1 : f32
|
||||
# CHECK: store %2, %arg2[%i0, %i1] : memref<32x32xf32>
|
||||
# CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()
|
||||
# CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()
|
||||
# CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()
|
||||
|
||||
def testMLIRScalarTypes(self):
|
||||
self.setUp()
|
||||
module = E.MLIRModule()
|
||||
t = module.make_scalar_type("bf16")
|
||||
self.assertIn("bf16", t.__str__())
|
||||
t = module.make_scalar_type("f16")
|
||||
self.assertIn("f16", t.__str__())
|
||||
t = module.make_scalar_type("f32")
|
||||
self.assertIn("f32", t.__str__())
|
||||
t = module.make_scalar_type("f64")
|
||||
self.assertIn("f64", t.__str__())
|
||||
t = module.make_scalar_type("i", 1)
|
||||
self.assertIn("i1", t.__str__())
|
||||
t = module.make_scalar_type("i", 8)
|
||||
self.assertIn("i8", t.__str__())
|
||||
t = module.make_scalar_type("i", 32)
|
||||
self.assertIn("i32", t.__str__())
|
||||
t = module.make_scalar_type("i", 123)
|
||||
self.assertIn("i123", t.__str__())
|
||||
t = module.make_scalar_type("index")
|
||||
self.assertIn("index", t.__str__())
|
||||
printWithCurrentFunctionName(str(module.make_scalar_type("bf16")))
|
||||
print(str(module.make_scalar_type("f16")))
|
||||
print(str(module.make_scalar_type("f32")))
|
||||
print(str(module.make_scalar_type("f64")))
|
||||
print(str(module.make_scalar_type("i", 1)))
|
||||
print(str(module.make_scalar_type("i", 8)))
|
||||
print(str(module.make_scalar_type("i", 32)))
|
||||
print(str(module.make_scalar_type("i", 123)))
|
||||
print(str(module.make_scalar_type("index")))
|
||||
# CHECK-LABEL: testMLIRScalarTypes
|
||||
# CHECK: bf16
|
||||
# CHECK: f16
|
||||
# CHECK: f32
|
||||
# CHECK: f64
|
||||
# CHECK: i1
|
||||
# CHECK: i8
|
||||
# CHECK: i32
|
||||
# CHECK: i123
|
||||
# CHECK: index
|
||||
|
||||
def testMLIRFunctionCreation(self):
|
||||
self.setUp()
|
||||
module = E.MLIRModule()
|
||||
t = module.make_scalar_type("f32")
|
||||
self.assertIn("f32", t.__str__())
|
||||
m = module.make_memref_type(t, [3, 4, -1, 5])
|
||||
self.assertIn("memref<3x4x?x5xf32>", m.__str__())
|
||||
f = module.make_function("copy", [m, m], [])
|
||||
self.assertIn(
|
||||
"func @copy(%arg0: memref<3x4x?x5xf32>, %arg1: memref<3x4x?x5xf32>) {",
|
||||
f.__str__())
|
||||
|
||||
f = module.make_function("sqrtf", [t], [t])
|
||||
self.assertIn("func @sqrtf(%arg0: f32) -> f32", f.__str__())
|
||||
printWithCurrentFunctionName(str(t))
|
||||
print(str(m))
|
||||
print(str(module.make_function("copy", [m, m], [])))
|
||||
print(str(module.make_function("sqrtf", [t], [t])))
|
||||
# CHECK-LABEL: testMLIRFunctionCreation
|
||||
# CHECK: f32
|
||||
# CHECK: memref<3x4x?x5xf32>
|
||||
# CHECK: func @copy(%arg0: memref<3x4x?x5xf32>, %arg1: memref<3x4x?x5xf32>) {
|
||||
# CHECK: func @sqrtf(%arg0: f32) -> f32
|
||||
|
||||
def testFunctionDeclaration(self):
|
||||
module = E.MLIRModule()
|
||||
self.setUp()
|
||||
boolAttr = self.module.boolAttr(True)
|
||||
t = module.make_memref_type(self.f32Type, [10])
|
||||
t = self.module.make_memref_type(self.f32Type, [10])
|
||||
t_llvm_noalias = t({"llvm.noalias": boolAttr})
|
||||
t_readonly = t({"readonly": boolAttr})
|
||||
f = module.declare_function("foo", [t, t_llvm_noalias, t_readonly], [])
|
||||
str = module.__str__()
|
||||
self.assertIn(
|
||||
"func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias: true}, memref<10xf32> {readonly: true})",
|
||||
str)
|
||||
f = self.module.declare_function("foo", [t, t_llvm_noalias, t_readonly], [])
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testFunctionDeclaration
|
||||
# CHECK: func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias: true}, memref<10xf32> {readonly: true})
|
||||
|
||||
def testMLIRBooleanCompilation(self):
|
||||
self.setUp()
|
||||
m = self.module.make_memref_type(self.boolType, [10]) # i1 tensor
|
||||
with self.module.function_context("mkbooltensor", [m, m], []) as f:
|
||||
input = E.IndexedValue(f.arg(0))
|
||||
|
@ -457,23 +445,66 @@ class EdscTest(unittest.TestCase):
|
|||
b3 = b2 | (k < j)
|
||||
output.store([i], input.load([i]) & b3)
|
||||
E.ret([])
|
||||
|
||||
self.module.compile()
|
||||
self.assertNotEqual(self.module.get_engine_address(), 0)
|
||||
self.module.compile()
|
||||
printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
|
||||
# CHECK-LABEL: testMLIRBooleanCompilation
|
||||
# CHECK: False
|
||||
|
||||
# Create 'addi' using the generic Op interface. We need an operation known
|
||||
# to the execution engine so that the engine can compile it.
|
||||
def testCustomOpCompilation(self):
|
||||
self.setUp()
|
||||
with self.module.function_context("adder", [self.i32Type], []) as f:
|
||||
c1 = E.op(
|
||||
"std.constant", [], [self.i32Type],
|
||||
value=self.module.integerAttr(self.i32Type, 42))
|
||||
E.op("std.addi", [c1, f.arg(0)], [self.i32Type])
|
||||
E.ret([])
|
||||
|
||||
self.module.compile()
|
||||
self.assertNotEqual(self.module.get_engine_address(), 0)
|
||||
printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
|
||||
# CHECK-LABEL: testCustomOpCompilation
|
||||
# CHECK: False
|
||||
|
||||
# Until python 3.6 this cannot be used because the order in the dict is not the
|
||||
# order of method declaration.
|
||||
def runTests(edscTest):
|
||||
def isTest(attr):
|
||||
return inspect.ismethod(attr) and "__init" not in str(attr)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
tests = filter(isTest, (getattr(edscTest, attr) for attr in dir(edscTest)))
|
||||
for test in tests:
|
||||
test()
|
||||
|
||||
# So instead one must list the functions in order of their Filecheck appearance.
|
||||
def main():
|
||||
edscTest = EdscTest()
|
||||
edscTest.testFunctionContext()
|
||||
edscTest.testMultipleFunctions()
|
||||
edscTest.testFunctionArgs()
|
||||
edscTest.testLoopContext()
|
||||
edscTest.testLoopNestContext()
|
||||
edscTest.testBlockContext()
|
||||
edscTest.testBlockContextAppend()
|
||||
edscTest.testBlockContextStandalone()
|
||||
edscTest.testBlockArguments()
|
||||
edscTest.testBr()
|
||||
edscTest.testBrDeclaration()
|
||||
edscTest.testBrArgs()
|
||||
edscTest.testCondBr()
|
||||
edscTest.testRet()
|
||||
edscTest.testSelectOp()
|
||||
edscTest.testCallOp()
|
||||
edscTest.testBooleanOps()
|
||||
edscTest.testDivisions()
|
||||
edscTest.testCustom()
|
||||
edscTest.testConstants()
|
||||
edscTest.testIndexedValue()
|
||||
edscTest.testMatrixMultiply()
|
||||
edscTest.testMLIRScalarTypes()
|
||||
edscTest.testMLIRFunctionCreation()
|
||||
edscTest.testFunctionDeclaration()
|
||||
edscTest.testMLIRBooleanCompilation()
|
||||
edscTest.testCustomOpCompilation()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue