forked from OSchip/llvm-project
[mlir][python] Add custom constructor for memref load
The type can be inferred trivially, but it is currently done as string stitching between ODS and C++ and is not easily exposed to Python. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D111712
This commit is contained in:
parent
cc83c2444f
commit
7fd6f40dbd
|
@ -0,0 +1,37 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
class LoadOp:
|
||||
"""Specialization for the MemRef load operation."""
|
||||
|
||||
def __init__(self,
|
||||
memref: Union[Operation, OpView, Value],
|
||||
indices: Optional[Union[Operation, OpView,
|
||||
Sequence[Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
"""Creates a memref load operation.
|
||||
|
||||
Args:
|
||||
memref: the buffer to load from.
|
||||
indices: the list of subscripts, may be empty for zero-dimensional
|
||||
buffers.
|
||||
loc: user-visible location of the operation.
|
||||
ip: insertion point.
|
||||
"""
|
||||
memref_resolved = _get_op_result_or_value(memref)
|
||||
indices_resolved = [] if indices is None else _get_op_results_or_values(
|
||||
indices)
|
||||
return_type = memref_resolved.type
|
||||
super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
|
|
@ -8,9 +8,11 @@ import mlir.dialects.memref as memref
|
|||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testSubViewAccessors
|
||||
@run
|
||||
def testSubViewAccessors():
|
||||
ctx = Context()
|
||||
module = Module.parse(
|
||||
|
@ -52,4 +54,20 @@ def testSubViewAccessors():
|
|||
print(subview.strides[1])
|
||||
|
||||
|
||||
run(testSubViewAccessors)
|
||||
# CHECK-LABEL: TEST: testCustomBuidlers
|
||||
@run
|
||||
def testCustomBuidlers():
|
||||
with Context() as ctx, Location.unknown(ctx):
|
||||
module = Module.parse(r"""
|
||||
func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
|
||||
return
|
||||
}
|
||||
""")
|
||||
func = module.body.operations[0]
|
||||
func_body = func.regions[0].blocks[0]
|
||||
with InsertionPoint.at_block_terminator(func_body):
|
||||
memref.LoadOp(func.arguments[0], func.arguments[1:])
|
||||
|
||||
# CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
|
||||
# CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
|
||||
print(module)
|
||||
|
|
Loading…
Reference in New Issue