Add a test example of calling a builtin function.

PiperOrigin-RevId: 235149430
This commit is contained in:
Brian Patton 2019-02-22 01:02:24 -08:00 committed by jpienaar
parent f0597cbf9f
commit d52e631359
1 changed files with 18 additions and 1 deletions

View File

@ -142,7 +142,7 @@ class EdscTest(unittest.TestCase):
with E.ContextManager(): with E.ContextManager():
module = E.MLIRModule() module = E.MLIRModule()
f32 = module.make_scalar_type("f32") f32 = module.make_scalar_type("f32")
func, arg = list(map(E.Expr, [E.Bindable(f32) for _ in range(2)])) func, arg = [E.Expr(E.Bindable(f32)) for _ in range(2)]
code = func(arg, result=f32) code = func(arg, result=f32)
self.assertIn("@$1($2)", str(code)) self.assertIn("@$1($2)", str(code))
@ -217,6 +217,23 @@ class EdscTest(unittest.TestCase):
self.assertIn("constant 123 : index", str) self.assertIn("constant 123 : index", str)
self.assertIn("constant @constants : () -> ()", str) self.assertIn("constant @constants : () -> ()", str)
def testMLIRBuiltinEmission(self):
module = E.MLIRModule()
m = module.make_memref_type(self.f32Type, [10]) # f32 tensor
f = module.make_function("call_builtin", [m, m], [])
with E.ContextManager():
emitter = E.MLIRFunctionEmitter(f)
input, output = list(map(E.Indexed, emitter.bind_function_arguments()))
fn = module.declare_function("sqrtf", [self.f32Type], [self.f32Type])
fn = emitter.bind_constant_function(fn)
zero = emitter.bind_constant_index(0)
emitter.emit_inplace(E.Block([
output.store([zero], fn(input.load([zero]), result=self.f32Type))
]))
str = f.__str__()
self.assertIn("%f = constant @sqrtf : (f32) -> f32", str)
self.assertIn("call_indirect %f(%0) : (f32) -> f32", str)
def testMLIRBooleanEmission(self): def testMLIRBooleanEmission(self):
m = self.module.make_memref_type(self.boolType, [10]) # i1 tensor m = self.module.make_memref_type(self.boolType, [10]) # i1 tensor
f = self.module.make_function("mkbooltensor", [m, m], []) f = self.module.make_function("mkbooltensor", [m, m], [])